diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 506e31ebfe2..2320a5ac39c 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -383,7 +383,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, want_version = _FAKE_VERSION if name == 'fake' else want_version if not need_download and want_version is not None: data_version = _dataset_version(folder_path[0], name) - need_download = data_version != want_version + need_download = LooseVersion(data_version) < LooseVersion(want_version) if need_download: logger.info(f'Dataset {name} version {data_version} out of date, ' f'latest version is {want_version}') diff --git a/mne/label.py b/mne/label.py index e0e6a8ccddd..e596b68c7b3 100644 --- a/mne/label.py +++ b/mne/label.py @@ -1874,9 +1874,11 @@ def _cortex_parcellation(subject, n_parcel, hemis, vertices_, graphs, return labels -def _read_annot_cands(dir_name): +def _read_annot_cands(dir_name, raise_error=True): """List the candidate parcellations.""" if not op.isdir(dir_name): + if not raise_error: + return list() raise IOError('Directory for annotation does not exist: %s', dir_name) cands = os.listdir(dir_name) diff --git a/mne/viz/__init__.py b/mne/viz/__init__.py index 9b54741afe7..a64bb4efdb9 100644 --- a/mne/viz/__init__.py +++ b/mne/viz/__init__.py @@ -6,7 +6,7 @@ from .topo import plot_topo_image_epochs, iter_topography from .utils import (tight_layout, mne_analyze_colormap, compare_fiff, ClickableImage, add_background_image, plot_sensors, - centers_to_edges) + centers_to_edges, concatenate_images) from ._3d import (plot_sparse_source_estimates, plot_source_estimates, plot_vector_source_estimates, plot_evoked_field, plot_dipole_locations, snapshot_brain_montage, diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 142c0320192..a3bff001192 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -9,6 +9,7 @@ import contextlib from functools import partial +from io import BytesIO import os import os.path as op import sys @@ -27,7 +28,7 @@ from .callback import (ShowView, IntSlider, TimeSlider, SmartSlider, BumpColorbarPoints, UpdateColorbarScale) -from ..utils import _show_help, _get_color_list +from ..utils import _show_help, _get_color_list, concatenate_images from .._3d import _process_clim, _handle_time, _check_views from ...externals.decorator import decorator @@ -593,6 +594,7 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True): self._configure_picking() self._configure_tool_bar() if self.notebook: + self._renderer._set_tool_bar(state=False) self.show() self._configure_trace_mode() self.toggle_interface() @@ -655,9 +657,20 @@ def ensure_minimum_sizes(self): yield finally: self.splitter.setSizes([sz[1], mpl_h]) + # 1. Process events _process_events(self.plotter) _process_events(self.plotter) - self.mpl_canvas.canvas.setMinimumSize(0, 0) + # 2. Get the window size that accommodates the size + sz = self.plotter.app_window.size() + # 3. Call app_window.setBaseSize and resize (in pyvistaqt) + self.plotter.window_size = (sz.width(), sz.height()) + # 4. Undo the min size setting and process events + self.plotter.interactor.setMinimumSize(0, 0) + _process_events(self.plotter) + _process_events(self.plotter) + # 5. Resize the window (again!) to the correct size + # (not sure why, but this is required on macOS at least) + self.plotter.window_size = (sz.width(), sz.height()) _process_events(self.plotter) _process_events(self.plotter) # sizes could change, update views @@ -1145,15 +1158,16 @@ def _set_annot(annot): from PyQt5.QtWidgets import QComboBox, QLabel dir_name = op.join(self._subjects_dir, self._subject_id, 'label') - cands = _read_annot_cands(dir_name) + cands = _read_annot_cands(dir_name, raise_error=False) self.tool_bar.addSeparator() self.tool_bar.addWidget(QLabel("Annotation")) self._annot_cands_widget = QComboBox() self.tool_bar.addWidget(self._annot_cands_widget) - self._annot_cands_widget.addItem('None') + cands = cands + ['None'] for cand in cands: self._annot_cands_widget.addItem(cand) self.annot = cands[0] + del cands # setup label extraction parameters def _set_label_mode(mode): @@ -1215,7 +1229,14 @@ def _save_movie_noname(self): return self.save_movie(None) def _screenshot(self): - if not self.notebook: + if self.notebook: + from PIL import Image + fname = self.actions.get("screenshot_field").value + fname = self._renderer._get_screenshot_filename() \ + if len(fname) == 0 else fname + img = self.screenshot(fname, time_viewer=True) + Image.fromarray(img).save(fname) + else: self.plotter._qt_screenshot() def _initialize_actions(self): @@ -1223,14 +1244,13 @@ def _initialize_actions(self): self._load_icons() self.tool_bar = self.window.addToolBar("toolbar") - def _add_action(self, name, desc, func, icon_name, qt_icon_name=None, + def _add_button(self, name, desc, func, icon_name, qt_icon_name=None, notebook=True): if self.notebook: if not notebook: return - from ipywidgets import Button - self.actions[name] = Button(description=desc, icon=icon_name) - self.actions[name].on_click(lambda x: func()) + self.actions[name] = self._renderer._add_button( + desc, func, icon_name) else: qt_icon_name = name if qt_icon_name is None else qt_icon_name self.actions[name] = self.tool_bar.addAction( @@ -1239,61 +1259,71 @@ def _add_action(self, name, desc, func, icon_name, qt_icon_name=None, func, ) + def _add_text_field(self, name, value, placeholder): + if not self.notebook: + return + self.actions[name] = self._renderer._add_text_field( + value, placeholder) + def _configure_tool_bar(self): self._initialize_actions() - self._add_action( + self._add_button( name="screenshot", desc="Take a screenshot", func=self._screenshot, - icon_name=None, - notebook=False, + icon_name="camera", + ) + self._add_text_field( + name="screenshot_field", + value=None, + placeholder="Type a file name", ) - self._add_action( + self._add_button( name="movie", desc="Save movie...", func=self._save_movie_noname, icon_name=None, notebook=False, ) - self._add_action( + self._add_button( name="visibility", desc="Toggle Visibility", func=self.toggle_interface, icon_name="eye", qt_icon_name="visibility_on", ) - self._add_action( + self._add_button( name="play", desc="Play/Pause", func=self.toggle_playback, icon_name=None, notebook=False, ) - self._add_action( + self._add_button( name="reset", desc="Reset", func=self.reset, icon_name="history", ) - self._add_action( + self._add_button( name="scale", desc="Auto-Scale", func=self.apply_auto_scaling, icon_name="magic", ) - self._add_action( + self._add_button( name="restore", desc="Restore scaling", func=self.restore_user_scaling, icon_name="reply", ) - self._add_action( + self._add_button( name="clear", desc="Clear traces", func=self.clear_glyphs, icon_name="trash", ) - self._add_action( + self._add_button( name="help", desc="Help", func=self.help, @@ -1302,10 +1332,7 @@ def _configure_tool_bar(self): ) if self.notebook: - from IPython import display - from ipywidgets import HBox - self.tool_bar = HBox(tuple(self.actions.values())) - display.display(self.tool_bar) + self.tool_bar = self._renderer._show_tool_bar(self.actions) else: # Qt shortcuts self.actions["movie"].setShortcut("ctrl+shift+s") @@ -1593,6 +1620,7 @@ def plot_time_course(self, hemi, vertex_id, color): if self.mpl_canvas is None: return time = self._data['time'].copy() # avoid circular ref + mni = None if hemi == 'vol': hemi_str = 'V' xfm = read_talxfm( @@ -1605,15 +1633,20 @@ def plot_time_course(self, hemi, vertex_id, color): mni = apply_trans(np.dot(xfm['trans'], src_mri_t), ijk) else: hemi_str = 'L' if hemi == 'lh' else 'R' - mni = vertex_to_mni( - vertices=vertex_id, - hemis=0 if hemi == 'lh' else 1, - subject=self._subject_id, - subjects_dir=self._subjects_dir - ) - label = "{}:{} MNI: {}".format( - hemi_str, str(vertex_id).ljust(6), - ', '.join('%5.1f' % m for m in mni)) + try: + mni = vertex_to_mni( + vertices=vertex_id, + hemis=0 if hemi == 'lh' else 1, + subject=self._subject_id, + subjects_dir=self._subjects_dir + ) + except Exception: + mni = None + if mni is not None: + mni = ' MNI: ' + ', '.join('%5.1f' % m for m in mni) + else: + mni = '' + label = "{}:{}{}".format(hemi_str, str(vertex_id).ljust(6), mni) act_data, smooth = self.act_data_smooth[hemi] if smooth is not None: act_data = smooth[vertex_id].dot(act_data)[0] @@ -2594,32 +2627,23 @@ def screenshot(self, mode='rgb', time_viewer=False): not self.separate_canvas: canvas = self.mpl_canvas.fig.canvas canvas.draw_idle() - # In theory, one of these should work: - # - # trace_img = np.frombuffer( - # canvas.tostring_rgb(), dtype=np.uint8) - # trace_img.shape = canvas.get_width_height()[::-1] + (3,) - # - # or - # - # trace_img = np.frombuffer( - # canvas.tostring_rgb(), dtype=np.uint8) - # size = time_viewer.mpl_canvas.getSize() - # trace_img.shape = (size.height(), size.width(), 3) - # - # But in practice, sometimes the sizes does not match the - # renderer tostring_rgb() size. So let's directly use what - # matplotlib does in lib/matplotlib/backends/backend_agg.py - # before calling tobytes(): - trace_img = np.asarray( - canvas.renderer._renderer).take([0, 1, 2], axis=2) - # need to slice into trace_img because generally it's a bit - # smaller - delta = trace_img.shape[1] - img.shape[1] - if delta > 0: - start = delta // 2 - trace_img = trace_img[:, start:start + img.shape[1]] - img = np.concatenate([img, trace_img], axis=0) + fig = self.mpl_canvas.fig + with BytesIO() as output: + # Need to pass dpi here so it uses the physical (HiDPI) DPI + # rather than logical DPI when saving in most cases. + # But when matplotlib uses HiDPI and VTK doesn't + # (e.g., macOS w/Qt 5.14+ and VTK9) then things won't work, + # so let's just calculate the DPI we need to get + # the correct size output based on the widths being equal + dpi = img.shape[1] / fig.get_size_inches()[0] + fig.savefig(output, dpi=dpi, format='raw', + facecolor=self._bg_color, edgecolor='none') + output.seek(0) + trace_img = np.reshape( + np.frombuffer(output.getvalue(), dtype=np.uint8), + newshape=(-1, img.shape[1], 4))[:, :, :3] + img = concatenate_images( + [img, trace_img], bgcolor=self._brain_color[:3]) return img @fill_doc diff --git a/mne/viz/_brain/tests/test.ipynb b/mne/viz/_brain/tests/test.ipynb index 80a8bec809e..ec7bfc13e60 100644 --- a/mne/viz/_brain/tests/test.ipynb +++ b/mne/viz/_brain/tests/test.ipynb @@ -27,9 +27,12 @@ "metadata": {}, "outputs": [], "source": [ + "from contextlib import contextmanager\n", "import os\n", - "import mne\n", + "from numpy.testing import assert_allclose\n", + "from ipywidgets import Button\n", "import matplotlib.pyplot as plt\n", + "import mne\n", "from mne.datasets import testing\n", "data_path = testing.data_path()\n", "sample_dir = os.path.join(data_path, 'MEG', 'sample')\n", @@ -39,16 +42,40 @@ "initial_time = 0.13\n", "mne.viz.set_3d_backend('notebook')\n", "brain_class = mne.viz.get_brain_class()\n", - "for interactive_state in (False, True):\n", - " plt.interactive(interactive_state)\n", + "\n", + "\n", + "@contextmanager\n", + "def interactive(on):\n", + " old = plt.isinteractive()\n", + " plt.interactive(on)\n", + " try:\n", + " yield\n", + " finally:\n", + " plt.interactive(old)\n", + "\n", + "with interactive(False):\n", " brain = stc.plot(subjects_dir=subjects_dir, initial_time=initial_time,\n", " clim=dict(kind='value', pos_lims=[3, 6, 9]),\n", " time_viewer=True,\n", - " hemi='split')\n", + " show_traces=True,\n", + " hemi='lh', size=300)\n", " assert isinstance(brain, brain_class)\n", " assert brain.notebook\n", " assert brain._renderer.figure.display is not None\n", " brain._update()\n", + " total_number_of_buttons = len([k for k in brain.actions.keys() if '_field' not in k])\n", + " number_of_buttons = 0\n", + " for action in brain.actions.values():\n", + " if isinstance(action, Button):\n", + " action.click()\n", + " number_of_buttons += 1\n", + " assert number_of_buttons == total_number_of_buttons\n", + " img_nv = brain.screenshot()\n", + " assert img_nv.shape == (300, 300, 3), img_nv.shape\n", + " img_v = brain.screenshot(time_viewer=True)\n", + " assert img_v.shape[1:] == (300, 3), img_v.shape\n", + " # XXX This rtol is not very good, ideally would be zero\n", + " assert_allclose(img_v.shape[0], img_nv.shape[0] * 1.25, err_msg=img_nv.shape, rtol=0.1)\n", " brain.close()" ] }, @@ -66,6 +93,13 @@ "mne.viz.set_3d_view(fig, 200, 70, focalpoint=[0, 0, 0])\n", "assert fig.display is None\n", "rend.show()\n", + "total_number_of_buttons = len([k for k in rend.actions.keys() if '_field' not in k])\n", + "number_of_buttons = 0\n", + "for action in rend.actions.values():\n", + " if isinstance(action, Button):\n", + " action.click()\n", + " number_of_buttons += 1\n", + "assert number_of_buttons == total_number_of_buttons\n", "assert fig.display is not None" ] } @@ -86,4 +120,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 6a0897bf12b..e206983d211 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -17,7 +17,7 @@ from mne import (read_source_estimate, read_evokeds, read_cov, read_forward_solution, pick_types_forward, - SourceEstimate, MixedSourceEstimate, + SourceEstimate, MixedSourceEstimate, write_surface, VolSourceEstimate) from mne.minimum_norm import apply_inverse, make_inverse_operator from mne.source_space import (read_source_spaces, vertex_to_mni, @@ -369,6 +369,57 @@ def test_brain_save_movie(tmpdir, renderer, brain_gc): brain.close() +_TINY_SIZE = (300, 250) + + +def tiny(tmpdir): + """Create a tiny fake brain.""" + # This is a minimal version of what we need for our viz-with-timeviewer + # support currently + subject = 'test' + subject_dir = tmpdir.mkdir(subject) + surf_dir = subject_dir.mkdir('surf') + rng = np.random.RandomState(0) + rr = rng.randn(4, 3) + tris = np.array([[0, 1, 2], [2, 1, 3]]) + curv = rng.randn(len(rr)) + with open(surf_dir.join('lh.curv'), 'wb') as fid: + fid.write(np.array([255, 255, 255], dtype=np.uint8)) + fid.write(np.array([len(rr), 0, 1], dtype='>i4')) + fid.write(curv.astype('>f4')) + write_surface(surf_dir.join('lh.white'), rr, tris) + write_surface(surf_dir.join('rh.white'), rr, tris) # needed for vertex tc + vertices = [np.arange(len(rr)), []] + data = rng.randn(len(rr), 10) + stc = SourceEstimate(data, vertices, 0, 1, subject) + brain = stc.plot(subjects_dir=tmpdir, hemi='lh', surface='white', + size=_TINY_SIZE) + # in principle this should be sufficient: + # + # ratio = brain.mpl_canvas.canvas.window().devicePixelRatio() + # + # but in practice VTK can mess up sizes, so let's just calculate it. + sz = brain.plotter.size() + sz = (sz.width(), sz.height()) + sz_ren = brain.plotter.renderer.GetSize() + ratio = np.median(np.array(sz_ren) / np.array(sz)) + return brain, ratio + + +def test_brain_screenshot(renderer_interactive, tmpdir, brain_gc): + """Test time viewer screenshot.""" + if renderer_interactive._get_3d_backend() != 'pyvista': + pytest.skip('TimeViewer tests only supported on PyVista') + tiny_brain, ratio = tiny(tmpdir) + img_nv = tiny_brain.screenshot(time_viewer=False) + want = (_TINY_SIZE[1] * ratio, _TINY_SIZE[0] * ratio, 3) + assert img_nv.shape == want + img_v = tiny_brain.screenshot(time_viewer=True) + assert img_v.shape[1:] == want[1:] + assert_allclose(img_v.shape[0], want[0] * 4 / 3, atol=3) # some slop + tiny_brain.close() + + @testing.requires_testing_data @pytest.mark.slowtest def test_brain_time_viewer(renderer_interactive, pixel_ratio, brain_gc): diff --git a/mne/viz/_brain/tests/test_notebook.py b/mne/viz/_brain/tests/test_notebook.py index 48c65c2d066..7c159326b74 100644 --- a/mne/viz/_brain/tests/test_notebook.py +++ b/mne/viz/_brain/tests/test_notebook.py @@ -1,5 +1,4 @@ import os -import pytest from mne.datasets import testing from mne.utils import requires_version @@ -7,7 +6,6 @@ PATH = os.path.dirname(os.path.realpath(__file__)) -@pytest.mark.slowtest @testing.requires_testing_data @requires_version('nbformat') @requires_version('nbclient') diff --git a/mne/viz/backends/_notebook.py b/mne/viz/backends/_notebook.py index 761f0b8a60f..e8bda5436d0 100644 --- a/mne/viz/backends/_notebook.py +++ b/mne/viz/backends/_notebook.py @@ -10,11 +10,53 @@ class _Renderer(_PyVistaRenderer): def __init__(self, *args, **kwargs): + self.tool_bar_state = True + self.tool_bar = None + self.actions = dict() kwargs["notebook"] = True super().__init__(*args, **kwargs) + def _screenshot(self): + fname = self.actions.get("screenshot_field").value + fname = self._get_screenshot_filename() if len(fname) == 0 else fname + self.screenshot(filename=fname) + + def _set_tool_bar(self, state): + self.tool_bar_state = state + + def _add_button(self, desc, func, icon_name): + from ipywidgets import Button + button = Button(tooltip=desc, icon=icon_name) + button.on_click(lambda x: func()) + return button + + def _add_text_field(self, value, placeholder): + from ipywidgets import Text + return Text(value=value, placeholder=placeholder) + + def _show_tool_bar(self, actions): + from IPython import display + from ipywidgets import HBox + tool_bar = HBox(tuple(actions.values())) + display.display(tool_bar) + return tool_bar + + def _configure_tool_bar(self): + self.actions["screenshot"] = self._add_button( + desc="Take a screenshot", + func=self._screenshot, + icon_name="camera", + ) + self.actions["screenshot_field"] = self._add_text_field( + value=None, + placeholder="Type a file name", + ) + self.tool_bar = self._show_tool_bar(self.actions) + def show(self): from IPython.display import display + if self.tool_bar_state: + self._configure_tool_bar() self.figure.display = self.plotter.show(use_ipyvtk=True, return_viewer=True) self.figure.display.layout.width = None # unlock the fixed layout diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 1b340d2a974..51e2492777d 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -12,6 +12,7 @@ # License: Simplified BSD from contextlib import contextmanager +from datetime import datetime from distutils.version import LooseVersion import os import sys @@ -212,6 +213,11 @@ def __init__(self, fig=None, size=(600, 600), bgcolor='black', self.update_lighting() + def _get_screenshot_filename(self): + now = datetime.now() + dt_string = now.strftime("_%Y-%m-%d_%H-%M-%S") + return "MNE" + dt_string + ".png" + @contextmanager def ensure_minimum_sizes(self): sz = self.figure.store['window_size'] @@ -227,17 +233,17 @@ def ensure_minimum_sizes(self): # 1. Process events _process_events(self.plotter) _process_events(self.plotter) - # 2. Get the window size that accommodates the size - sz = self.plotter.app_window.size() - # 3. Call app_window.setBaseSize and resize (in pyvistaqt) - self.plotter.window_size = (sz.width(), sz.height()) - # 4. Undo the min size setting and process events + # 2. Get the window and interactor sizes that work + win_sz = self.plotter.app_window.size() + ren_sz = self.plotter.interactor.size() + # 3. Undo the min size setting and process events self.plotter.interactor.setMinimumSize(0, 0) _process_events(self.plotter) _process_events(self.plotter) - # 5. Resize the window (again!) to the correct size + # 4. Resize the window and interactor to the correct size # (not sure why, but this is required on macOS at least) - self.plotter.window_size = (sz.width(), sz.height()) + self.plotter.window_size = (win_sz.width(), win_sz.height()) + self.plotter.interactor.resize(ren_sz.width(), ren_sz.height()) _process_events(self.plotter) _process_events(self.plotter) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 515311b365d..5a09e7dfefb 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -2287,3 +2287,47 @@ def _ndarray_to_fig(img): fig = _figure_agg(dpi=dpi, figsize=figsize, frameon=False) fig.figimage(img, resize=True) return fig + + +def concatenate_images(images, axis=0, bgcolor='black', centered=True): + """Concatenate a list of images. + + Parameters + ---------- + images : list of ndarray + The list of images to concatenate. + axis : 0 or 1 + The images are concatenated horizontally if 0 and vertically otherwise. + The default orientation is horizontal. + bgcolor : str | list + The color of the background. The name of the color is accepted + (e.g 'red') or a list of RGB values between 0 and 1. Defaults to + 'black'. + centered : bool + If True, the images are centered. Defaults to True. + + Returns + ------- + img : ndarray + The concatenated image. + """ + from matplotlib.colors import colorConverter + if isinstance(bgcolor, str): + bgcolor = colorConverter.to_rgb(bgcolor) + bgcolor = np.asarray(bgcolor) * 255 + funcs = [np.sum, np.max] + ret_shape = np.asarray([ + funcs[axis]([image.shape[0] for image in images]), + funcs[1 - axis]([image.shape[1] for image in images]), + ]) + ret = np.zeros((ret_shape[0], ret_shape[1], 3), dtype=np.uint8) + ret[:, :, :] = bgcolor + ptr = np.array([0, 0]) + sec = np.array([0 == axis, 1 == axis]).astype(int) + for image in images: + shape = image.shape[:-1] + dec = ptr + dec += ((ret_shape - shape) // 2) * (1 - sec) if centered else 0 + ret[dec[0]:dec[0] + shape[0], dec[1]:dec[1] + shape[1], :] = image + ptr += shape * sec + return ret