diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index e5ecd3192de..97069ac3fef 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -97,6 +97,7 @@ Changelog - Add ``border`` argument to :func:`mne.viz.plot_topomap`. ``border`` controls the value of the edge points to which topomap values are extrapolated. ``border='mean'`` sets these points value to the average of their neighbours. By `Mikołaj Magnuski`_ +- Add function :func:`mne.viz.link_brains` to link time properties of multiple brain objects interactively by `Guillaume Favelier`_ Bug ~~~ diff --git a/doc/python_reference.rst b/doc/python_reference.rst index 96b9b77a492..e837629e284 100644 --- a/doc/python_reference.rst +++ b/doc/python_reference.rst @@ -261,6 +261,7 @@ Visualization plot_sensors_connectivity plot_snr_estimate plot_source_estimates + link_brains plot_volume_source_estimates plot_vector_source_estimates plot_sparse_source_estimates diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 68b4e7772eb..cff65d52d7c 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -16,6 +16,7 @@ from itertools import cycle import os.path as op import warnings +import collections from functools import partial import numpy as np @@ -1536,6 +1537,35 @@ def _plot_mpl_stc(stc, subject=None, surface='inflated', hemi='lh', return fig +def link_brains(brains): + """Plot multiple SourceEstimate objects with PyVista. + + Parameters + ---------- + brains : list, tuple or np.ndarray + The collection of brains to plot. + """ + from .backends.renderer import get_3d_backend + if get_3d_backend() != 'pyvista': + raise NotImplementedError("Expected 3d backend is pyvista but" + " {} was given.".format(get_3d_backend())) + from ._brain import _Brain, _TimeViewer, _LinkViewer + if not isinstance(brains, collections.Iterable): + brains = [brains] + if len(brains) == 0: + raise ValueError("The collection of brains is empty.") + for brain in brains: + if isinstance(brain, _Brain): + # check if the _TimeViewer wrapping is not already applied + if not hasattr(brain, 'time_viewer') or brain.time_viewer is None: + brain = _TimeViewer(brain) + else: + raise TypeError("Expected type is Brain but" + " {} was given.".format(type(brain))) + # link brains properties + _LinkViewer(brains) + + @verbose def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', colormap='auto', time_label='auto', diff --git a/mne/viz/__init__.py b/mne/viz/__init__.py index 60b7b7f149f..34ef3b71d1a 100644 --- a/mne/viz/__init__.py +++ b/mne/viz/__init__.py @@ -10,7 +10,8 @@ plot_vector_source_estimates, plot_evoked_field, plot_dipole_locations, snapshot_brain_montage, plot_head_positions, plot_alignment, plot_brain_colorbar, - plot_volume_source_estimates, plot_sensors_connectivity) + plot_volume_source_estimates, plot_sensors_connectivity, + link_brains) from .misc import (plot_cov, plot_csd, plot_bem, plot_events, plot_source_spectrogram, _get_presser, plot_dipole_amplitudes, plot_ideal_filter, plot_filter, diff --git a/mne/viz/_brain/__init__.py b/mne/viz/_brain/__init__.py index e9e60e25e39..79842e6e5c7 100644 --- a/mne/viz/_brain/__init__.py +++ b/mne/viz/_brain/__init__.py @@ -10,6 +10,6 @@ # License: Simplified BSD from ._brain import _Brain -from ._timeviewer import _TimeViewer +from ._timeviewer import _TimeViewer, _LinkViewer __all__ = ['_Brain'] diff --git a/mne/viz/_brain/_timeviewer.py b/mne/viz/_brain/_timeviewer.py index e3d15d83d35..b2142b98cd8 100644 --- a/mne/viz/_brain/_timeviewer.py +++ b/mne/viz/_brain/_timeviewer.py @@ -15,16 +15,19 @@ def __init__(self, plotter=None, callback=None, name=None): self.plotter = plotter self.callback = callback self.name = name + self.slider_rep = None def __call__(self, value): """Round the label of the slider.""" idx = int(round(value)) - for slider in self.plotter.slider_widgets: - name = getattr(slider, "name", None) - if name == self.name: - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(idx) - self.callback(idx) + if self.slider_rep is None: + for slider in self.plotter.slider_widgets: + name = getattr(slider, "name", None) + if name == self.name: + self.slider_rep = slider.GetRepresentation() + if self.slider_rep is not None: + self.slider_rep.SetValue(idx) + self.callback(idx) class UpdateColorbarScale(object): @@ -33,6 +36,7 @@ class UpdateColorbarScale(object): def __init__(self, plotter=None, brain=None): self.plotter = plotter self.brain = brain + self.slider_rep = None def __call__(self, value): """Update the colorbar sliders.""" @@ -121,6 +125,7 @@ def __init__(self, plotter=None, brain=None, orientation=None, self.col = col self.hemi = hemi self.name = name + self.slider_rep = None def __call__(self, value, update_widget=False): """Update the view.""" @@ -131,12 +136,40 @@ def __call__(self, value, update_widget=False): idx = self.orientation.index(value) else: idx = self.short_orientation.index(value) - for slider in self.plotter.slider_widgets: - name = getattr(slider, "name", None) - if name == self.name: - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(idx) - slider_rep.SetTitleText(self.orientation[idx]) + if self.slider_rep is None: + for slider in self.plotter.slider_widgets: + name = getattr(slider, "name", None) + if name == self.name: + self.slider_rep = slider.GetRepresentation() + if self.slider_rep is not None: + self.slider_rep.SetValue(idx) + self.slider_rep.SetTitleText(self.orientation[idx]) + + +class SmartSlider(object): + """Class to manage smart slider. + + It stores it's own slider representation for efficiency + and uses it when necessary. + """ + + def __init__(self, plotter=None, callback=None, name=None): + self.plotter = plotter + self.callback = callback + self.name = name + self.slider_rep = None + + def __call__(self, value, update_widget=False): + """Update the value.""" + self.callback(value) + if update_widget: + if self.slider_rep is None: + for slider in self.plotter.slider_widgets: + name = getattr(slider, "name", None) + if name == self.name: + self.slider_rep = slider.GetRepresentation() + if self.slider_rep is not None: + self.slider_rep.SetValue(value) class _TimeViewer(object): @@ -144,6 +177,7 @@ class _TimeViewer(object): def __init__(self, brain): self.brain = brain + self.brain.time_viewer = self self.plotter = brain._renderer.plotter # orientation slider @@ -170,7 +204,7 @@ def __init__(self, brain): for ri, view in enumerate(self.brain._views): self.plotter.subplot(ri, ci) name = "orientation_" + str(ri) + "_" + str(ci) - self.show_view = ShowView( + self.orientation_call = ShowView( plotter=self.plotter, brain=self.brain, orientation=orientation, @@ -180,7 +214,7 @@ def __init__(self, brain): name=name ) orientation_slider = self.plotter.add_text_slider_widget( - self.show_view, + self.orientation_call, value=0, data=orientation, pointa=(0.82, 0.74), @@ -189,7 +223,7 @@ def __init__(self, brain): ) orientation_slider.name = name self.set_slider_style(orientation_slider, show_label=False) - self.show_view(view, update_widget=True) + self.orientation_call(view, update_widget=True) # necessary because show_view modified subplot if self.brain._hemi == 'split': @@ -205,20 +239,20 @@ def __init__(self, brain): # smoothing slider default_smoothing_value = 7 - self.set_smoothing = IntSlider( + self.smoothing_call = IntSlider( plotter=self.plotter, callback=brain.set_data_smoothing, name="smoothing" ) smoothing_slider = self.plotter.add_slider_widget( - self.set_smoothing, + self.smoothing_call, value=default_smoothing_value, rng=[0, 15], title="smoothing", pointa=(0.82, 0.90), pointb=(0.98, 0.90) ) smoothing_slider.name = 'smoothing' - self.set_smoothing(default_smoothing_value) + self.smoothing_call(default_smoothing_value) # time label self.time_actor = brain._data.get('time_actor') @@ -228,25 +262,37 @@ def __init__(self, brain): # time slider max_time = len(brain._data['time']) - 1 + self.time_call = SmartSlider( + plotter=self.plotter, + callback=self.brain.set_time_point, + name="time" + ) time_slider = self.plotter.add_slider_widget( - brain.set_time_point, + self.time_call, value=brain._data['time_idx'], rng=[0, max_time], pointa=(0.23, 0.1), pointb=(0.77, 0.1), event_type='always' ) - time_slider.name = "time_slider" + time_slider.name = "time" # playback speed default_playback_speed = 0.05 + self.playback_speed_call = SmartSlider( + plotter=self.plotter, + callback=self.set_playback_speed, + name="playback_speed" + ) playback_speed_slider = self.plotter.add_slider_widget( - self.set_playback_speed, + self.playback_speed_call, value=default_playback_speed, - rng=[0.01, 1], title="playback speed", + rng=[0.01, 1], title="speed", pointa=(0.02, 0.1), - pointb=(0.18, 0.1) + pointb=(0.18, 0.1), + event_type='always' ) + playback_speed_slider.name = "playback_speed" # colormap slider scaling_limits = [0.2, 2.0] @@ -254,13 +300,13 @@ def __init__(self, brain): pointb = np.array((0.98, 0.26)) shift = np.array([0, 0.08]) fmin = brain._data["fmin"] - self.update_fmin = BumpColorbarPoints( + self.fmin_call = BumpColorbarPoints( plotter=self.plotter, brain=brain, name="fmin" ) fmin_slider = self.plotter.add_slider_widget( - self.update_fmin, + self.fmin_call, value=fmin, rng=_get_range(brain), title="clim", pointa=pointa, @@ -268,14 +314,15 @@ def __init__(self, brain): event_type="always", ) fmin_slider.name = "fmin" + self.fmin_slider_rep = fmin_slider.GetRepresentation() fmid = brain._data["fmid"] - self.update_fmid = BumpColorbarPoints( + self.fmid_call = BumpColorbarPoints( plotter=self.plotter, brain=brain, name="fmid", ) fmid_slider = self.plotter.add_slider_widget( - self.update_fmid, + self.fmid_call, value=fmid, rng=_get_range(brain), title="", pointa=pointa + shift, @@ -283,14 +330,15 @@ def __init__(self, brain): event_type="always", ) fmid_slider.name = "fmid" + self.fmid_slider_rep = fmid_slider.GetRepresentation() fmax = brain._data["fmax"] - self.update_fmax = BumpColorbarPoints( + self.fmax_call = BumpColorbarPoints( plotter=self.plotter, brain=brain, name="fmax", ) fmax_slider = self.plotter.add_slider_widget( - self.update_fmax, + self.fmax_call, value=fmax, rng=_get_range(brain), title="", pointa=pointa + 2 * shift, @@ -298,12 +346,13 @@ def __init__(self, brain): event_type="always", ) fmax_slider.name = "fmax" - self.update_fscale = UpdateColorbarScale( + self.fmax_slider_rep = fmax_slider.GetRepresentation() + self.fscale_call = UpdateColorbarScale( plotter=self.plotter, brain=brain, ) fscale_slider = self.plotter.add_slider_widget( - self.update_fscale, + self.fscale_call, value=1.0, rng=scaling_limits, title="fscale", pointa=(0.82, 0.10), @@ -311,7 +360,7 @@ def __init__(self, brain): ) fscale_slider.name = "fscale" - # add toggle to start/stop playback + # add toggle to start/pause playback self.playback = False self.playback_speed = default_playback_speed self.refresh_rate_ms = max(int(round(1000. / 60.)), 1) @@ -343,44 +392,23 @@ def __init__(self, brain): def toggle_interface(self): self.visibility = not self.visibility for slider in self.plotter.slider_widgets: + slider_rep = slider.GetRepresentation() if self.visibility: - slider.On() + slider_rep.VisibilityOn() else: - slider.Off() + slider_rep.VisibilityOff() def apply_auto_scaling(self): self.brain.update_auto_scaling() - fmin = self.brain._data['fmin'] - fmid = self.brain._data['fmid'] - fmax = self.brain._data['fmax'] - for slider in self.plotter.slider_widgets: - name = getattr(slider, "name", None) - if name == "fmin": - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(fmin) - elif name == "fmid": - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(fmid) - elif name == "fmax": - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(fmax) + self.fmin_slider_rep.SetValue(self.brain._data['fmin']) + self.fmid_slider_rep.SetValue(self.brain._data['fmid']) + self.fmax_slider_rep.SetValue(self.brain._data['fmax']) def restore_user_scaling(self): self.brain.update_auto_scaling(restore=True) - fmin = self.brain._data['fmin'] - fmid = self.brain._data['fmid'] - fmax = self.brain._data['fmax'] - for slider in self.plotter.slider_widgets: - name = getattr(slider, "name", None) - if name == "fmin": - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(fmin) - elif name == "fmid": - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(fmid) - elif name == "fmax": - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(fmax) + self.fmin_slider_rep.SetValue(self.brain._data['fmin']) + self.fmid_slider_rep.SetValue(self.brain._data['fmid']) + self.fmax_slider_rep.SetValue(self.brain._data['fmax']) def toggle_playback(self): self.playback = not self.playback @@ -407,12 +435,7 @@ def play(self): time_point = min(self.brain._current_time + time_shift, max_time) ifunc = interp1d(time_data, times) idx = ifunc(time_point) - self.brain.set_time_point(idx) - for slider in self.plotter.slider_widgets: - name = getattr(slider, "name", None) - if name == "time_slider": - slider_rep = slider.GetRepresentation() - slider_rep.SetValue(idx) + self.time_call(idx, update_widget=True) if time_point == max_time: self.playback = False self.plotter.update() # critical for smooth animation @@ -430,6 +453,59 @@ def set_slider_style(self, slider, show_label=True): slider_rep.ShowSliderLabelOff() +class _LinkViewer(object): + """Class to link multiple _TimeViewer objects.""" + + def __init__(self, brains): + self.brains = brains + self.time_viewers = [brain.time_viewer for brain in brains] + + # link time sliders + self.link_sliders( + name="time", + callback=self.set_time_point, + event_type="always" + ) + + # link playback speed sliders + self.link_sliders( + name="playback_speed", + callback=self.set_playback_speed, + event_type="always" + ) + + # link toggle to start/pause playback + for time_viewer in self.time_viewers: + plotter = time_viewer.plotter + plotter.clear_events_for_key('space') + plotter.add_key_event('space', self.toggle_playback) + + def set_time_point(self, value): + for time_viewer in self.time_viewers: + time_viewer.time_call(value, update_widget=True) + + def set_playback_speed(self, value): + for time_viewer in self.time_viewers: + time_viewer.playback_speed_call(value, update_widget=True) + + def toggle_playback(self): + for time_viewer in self.time_viewers: + time_viewer.toggle_playback() + + def link_sliders(self, name, callback, event_type): + from ..backends._pyvista import _update_slider_callback + for time_viewer in self.time_viewers: + plotter = time_viewer.plotter + for slider in plotter.slider_widgets: + slider_name = getattr(slider, "name", None) + if slider_name == name: + _update_slider_callback( + slider=slider, + callback=callback, + event_type=event_type + ) + + def _set_text_style(text_actor): if text_actor is not None: prop = text_actor.GetTextProperty() diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 5044c523340..8f6e1aa76bd 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -17,7 +17,7 @@ from mne import SourceEstimate, read_source_estimate from mne.source_space import read_source_spaces from mne.datasets import testing -from mne.viz._brain import _Brain, _TimeViewer +from mne.viz._brain import _Brain, _TimeViewer, _LinkViewer from mne.viz._brain.colormap import calculate_lut from matplotlib import cm @@ -170,22 +170,27 @@ def test_brain_timeviewer(renderer): colormap='hot', vertices=vertices, colorbar=True) - brain_data.set_time_point(time_idx=0) - time_viewer = _TimeViewer(brain_data) - time_viewer.show_view('lat', update_widget=True) - time_viewer.show_view('medial', update_widget=True) - time_viewer.set_smoothing(value=1) - time_viewer.update_fmin(value=12.0) - time_viewer.update_fmax(value=4.0) - time_viewer.update_fmid(value=6.0) - time_viewer.update_fmid(value=4.0) - time_viewer.update_fscale(value=1.1) + time_viewer.time_call(value=0) + time_viewer.orientation_call(value='lat', update_widget=True) + time_viewer.orientation_call(value='medial', update_widget=True) + time_viewer.smoothing_call(value=1) + time_viewer.fmin_call(value=12.0) + time_viewer.fmax_call(value=4.0) + time_viewer.fmid_call(value=6.0) + time_viewer.fmid_call(value=4.0) + time_viewer.fscale_call(value=1.1) time_viewer.toggle_interface() + time_viewer.playback_speed_call(value=0.1) time_viewer.toggle_playback() time_viewer.apply_auto_scaling() time_viewer.restore_user_scaling() + link_viewer = _LinkViewer([brain_data]) + link_viewer.set_time_point(value=0) + link_viewer.set_playback_speed(value=0.1) + link_viewer.toggle_playback() + def test_brain_colormap(): """Test brain's colormap functions.""" diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 91456a4213f..b6a32a73873 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -604,3 +604,23 @@ def _set_mesh_scalars(mesh, scalars, name): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) mesh.point_arrays[name] = scalars + + +def _update_slider_callback(slider, callback, event_type): + from pyvista.utilities import try_callback + + def _the_callback(widget, event): + value = widget.GetRepresentation().GetValue() + if hasattr(callback, '__call__'): + try_callback(callback, value) + return + + if event_type == 'start': + event = vtk.vtkCommand.StartInteractionEvent + elif event_type == 'end': + event = vtk.vtkCommand.EndInteractionEvent + elif event_type == 'always': + event = vtk.vtkCommand.InteractionEvent + + slider.RemoveObserver(event) + slider.AddObserver(event, _the_callback) diff --git a/mne/viz/backends/renderer.py b/mne/viz/backends/renderer.py index f3a7d59f24c..3c9b367f40f 100644 --- a/mne/viz/backends/renderer.py +++ b/mne/viz/backends/renderer.py @@ -75,6 +75,8 @@ def set_3d_backend(backend_name): +--------------------------------------+--------+---------+ | :func:`snapshot_brain_montage` | ✓ | ✓ | +--------------------------------------+--------+---------+ + | :func:`link_brains` | | ✓ | + +--------------------------------------+--------+---------+ +--------------------------------------+--------+---------+ | **3D feature:** | +--------------------------------------+--------+---------+ diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 4584a93b284..51745ab1dc5 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -29,7 +29,8 @@ from mne.viz import (plot_sparse_source_estimates, plot_source_estimates, snapshot_brain_montage, plot_head_positions, plot_alignment, plot_volume_source_estimates, - plot_sensors_connectivity, plot_brain_colorbar) + plot_sensors_connectivity, plot_brain_colorbar, + link_brains) from mne.viz.utils import _fake_click from mne.utils import (requires_mayavi, requires_pysurfer, run_tests_if_main, requires_nibabel, check_version, requires_dipy, @@ -665,4 +666,39 @@ def test_mixed_sources_plot_surface(): colorbar=False) +@testing.requires_testing_data +@traits_test +def test_link_brains(renderer): + """Test plotting linked brains.""" + if renderer.get_3d_backend() == "mayavi": + pytest.skip() # Skip PySurfer.TimeViewer + else: + # Disable testing to allow interactive window + renderer.MNE_3D_BACKEND_TESTING = False + with pytest.raises(ValueError, match='is empty'): + link_brains([]) + with pytest.raises(TypeError, match='type is Brain'): + link_brains('foo') + + sample_src = read_source_spaces(src_fname) + vertices = [s['vertno'] for s in sample_src] + n_time = 5 + n_verts = sum(len(v) for v in vertices) + stc_data = np.zeros((n_verts * n_time)) + stc_size = stc_data.size + stc_data[(np.random.rand(stc_size // 20) * stc_size).astype(int)] = \ + np.random.RandomState(0).rand(stc_data.size // 20) + stc_data.shape = (n_verts, n_time) + stc = SourceEstimate(stc_data, vertices, 1, 1) + + colormap = 'mne_analyze' + brain = plot_source_estimates( + stc, 'sample', colormap=colormap, + background=(1, 1, 0), + subjects_dir=subjects_dir, colorbar=True, + clim='auto' + ) + link_brains(brain) + + run_tests_if_main()