Skip to content

Commit

Permalink
MRG, ENH: Add a screenshot button to the notebook 3d backend (#8708)
Browse files Browse the repository at this point in the history
* Add screenshot button

* Ensure valid filename

* Make the name shorter

* Use Brain screenshot

* Click on all buttons

* Count the buttons too

* Add a tool bar to the standard _Renderer

* Improve testing of standard _Renderer

* DRY a little bit

* Use concatenate_images

* Make it shorter and more complicated

* Fix style

* Add centered parameter

* Comment slicing

* make it work on mac

* Remove cruft

* Update comments

* Remove more comments

* Generate screenshot filename

* Start over and test

* Test both qt and notebook

* The pragmatic approach

* Improve testing

* Fix test

* ENH: Faster test

* BUG: More explicit height

* Fix dangling objects issue

* Change order

* Try #8082

* FIX: Fix sizing

* FIX: Use concatenate_images

* FIX: dtype

* MAINT: Notebook test

* FIX: Flake

* Speed up test.ipynb

* FIX: Bad Qt/VTK combo

Co-authored-by: Alexandre Gramfort <[email protected]>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
3 people authored Jan 21, 2021
1 parent 94a75a5 commit 72c8f61
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 77 deletions.
2 changes: 1 addition & 1 deletion mne/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True,
want_version = _FAKE_VERSION if name == 'fake' else want_version
if not need_download and want_version is not None:
data_version = _dataset_version(folder_path[0], name)
need_download = data_version != want_version
need_download = LooseVersion(data_version) < LooseVersion(want_version)
if need_download:
logger.info(f'Dataset {name} version {data_version} out of date, '
f'latest version is {want_version}')
Expand Down
4 changes: 3 additions & 1 deletion mne/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,9 +1874,11 @@ def _cortex_parcellation(subject, n_parcel, hemis, vertices_, graphs,
return labels


def _read_annot_cands(dir_name):
def _read_annot_cands(dir_name, raise_error=True):
"""List the candidate parcellations."""
if not op.isdir(dir_name):
if not raise_error:
return list()
raise IOError('Directory for annotation does not exist: %s',
dir_name)
cands = os.listdir(dir_name)
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .topo import plot_topo_image_epochs, iter_topography
from .utils import (tight_layout, mne_analyze_colormap, compare_fiff,
ClickableImage, add_background_image, plot_sensors,
centers_to_edges)
centers_to_edges, concatenate_images)
from ._3d import (plot_sparse_source_estimates, plot_source_estimates,
plot_vector_source_estimates, plot_evoked_field,
plot_dipole_locations, snapshot_brain_montage,
Expand Down
142 changes: 83 additions & 59 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import contextlib
from functools import partial
from io import BytesIO
import os
import os.path as op
import sys
Expand All @@ -27,7 +28,7 @@
from .callback import (ShowView, IntSlider, TimeSlider, SmartSlider,
BumpColorbarPoints, UpdateColorbarScale)

from ..utils import _show_help, _get_color_list
from ..utils import _show_help, _get_color_list, concatenate_images
from .._3d import _process_clim, _handle_time, _check_views

from ...externals.decorator import decorator
Expand Down Expand Up @@ -593,6 +594,7 @@ def setup_time_viewer(self, time_viewer=True, show_traces=True):
self._configure_picking()
self._configure_tool_bar()
if self.notebook:
self._renderer._set_tool_bar(state=False)
self.show()
self._configure_trace_mode()
self.toggle_interface()
Expand Down Expand Up @@ -655,9 +657,20 @@ def ensure_minimum_sizes(self):
yield
finally:
self.splitter.setSizes([sz[1], mpl_h])
# 1. Process events
_process_events(self.plotter)
_process_events(self.plotter)
self.mpl_canvas.canvas.setMinimumSize(0, 0)
# 2. Get the window size that accommodates the size
sz = self.plotter.app_window.size()
# 3. Call app_window.setBaseSize and resize (in pyvistaqt)
self.plotter.window_size = (sz.width(), sz.height())
# 4. Undo the min size setting and process events
self.plotter.interactor.setMinimumSize(0, 0)
_process_events(self.plotter)
_process_events(self.plotter)
# 5. Resize the window (again!) to the correct size
# (not sure why, but this is required on macOS at least)
self.plotter.window_size = (sz.width(), sz.height())
_process_events(self.plotter)
_process_events(self.plotter)
# sizes could change, update views
Expand Down Expand Up @@ -1145,15 +1158,16 @@ def _set_annot(annot):

from PyQt5.QtWidgets import QComboBox, QLabel
dir_name = op.join(self._subjects_dir, self._subject_id, 'label')
cands = _read_annot_cands(dir_name)
cands = _read_annot_cands(dir_name, raise_error=False)
self.tool_bar.addSeparator()
self.tool_bar.addWidget(QLabel("Annotation"))
self._annot_cands_widget = QComboBox()
self.tool_bar.addWidget(self._annot_cands_widget)
self._annot_cands_widget.addItem('None')
cands = cands + ['None']
for cand in cands:
self._annot_cands_widget.addItem(cand)
self.annot = cands[0]
del cands

# setup label extraction parameters
def _set_label_mode(mode):
Expand Down Expand Up @@ -1215,22 +1229,28 @@ def _save_movie_noname(self):
return self.save_movie(None)

def _screenshot(self):
if not self.notebook:
if self.notebook:
from PIL import Image
fname = self.actions.get("screenshot_field").value
fname = self._renderer._get_screenshot_filename() \
if len(fname) == 0 else fname
img = self.screenshot(fname, time_viewer=True)
Image.fromarray(img).save(fname)
else:
self.plotter._qt_screenshot()

def _initialize_actions(self):
if not self.notebook:
self._load_icons()
self.tool_bar = self.window.addToolBar("toolbar")

def _add_action(self, name, desc, func, icon_name, qt_icon_name=None,
def _add_button(self, name, desc, func, icon_name, qt_icon_name=None,
notebook=True):
if self.notebook:
if not notebook:
return
from ipywidgets import Button
self.actions[name] = Button(description=desc, icon=icon_name)
self.actions[name].on_click(lambda x: func())
self.actions[name] = self._renderer._add_button(
desc, func, icon_name)
else:
qt_icon_name = name if qt_icon_name is None else qt_icon_name
self.actions[name] = self.tool_bar.addAction(
Expand All @@ -1239,61 +1259,71 @@ def _add_action(self, name, desc, func, icon_name, qt_icon_name=None,
func,
)

def _add_text_field(self, name, value, placeholder):
if not self.notebook:
return
self.actions[name] = self._renderer._add_text_field(
value, placeholder)

def _configure_tool_bar(self):
self._initialize_actions()
self._add_action(
self._add_button(
name="screenshot",
desc="Take a screenshot",
func=self._screenshot,
icon_name=None,
notebook=False,
icon_name="camera",
)
self._add_text_field(
name="screenshot_field",
value=None,
placeholder="Type a file name",
)
self._add_action(
self._add_button(
name="movie",
desc="Save movie...",
func=self._save_movie_noname,
icon_name=None,
notebook=False,
)
self._add_action(
self._add_button(
name="visibility",
desc="Toggle Visibility",
func=self.toggle_interface,
icon_name="eye",
qt_icon_name="visibility_on",
)
self._add_action(
self._add_button(
name="play",
desc="Play/Pause",
func=self.toggle_playback,
icon_name=None,
notebook=False,
)
self._add_action(
self._add_button(
name="reset",
desc="Reset",
func=self.reset,
icon_name="history",
)
self._add_action(
self._add_button(
name="scale",
desc="Auto-Scale",
func=self.apply_auto_scaling,
icon_name="magic",
)
self._add_action(
self._add_button(
name="restore",
desc="Restore scaling",
func=self.restore_user_scaling,
icon_name="reply",
)
self._add_action(
self._add_button(
name="clear",
desc="Clear traces",
func=self.clear_glyphs,
icon_name="trash",
)
self._add_action(
self._add_button(
name="help",
desc="Help",
func=self.help,
Expand All @@ -1302,10 +1332,7 @@ def _configure_tool_bar(self):
)

if self.notebook:
from IPython import display
from ipywidgets import HBox
self.tool_bar = HBox(tuple(self.actions.values()))
display.display(self.tool_bar)
self.tool_bar = self._renderer._show_tool_bar(self.actions)
else:
# Qt shortcuts
self.actions["movie"].setShortcut("ctrl+shift+s")
Expand Down Expand Up @@ -1593,6 +1620,7 @@ def plot_time_course(self, hemi, vertex_id, color):
if self.mpl_canvas is None:
return
time = self._data['time'].copy() # avoid circular ref
mni = None
if hemi == 'vol':
hemi_str = 'V'
xfm = read_talxfm(
Expand All @@ -1605,15 +1633,20 @@ def plot_time_course(self, hemi, vertex_id, color):
mni = apply_trans(np.dot(xfm['trans'], src_mri_t), ijk)
else:
hemi_str = 'L' if hemi == 'lh' else 'R'
mni = vertex_to_mni(
vertices=vertex_id,
hemis=0 if hemi == 'lh' else 1,
subject=self._subject_id,
subjects_dir=self._subjects_dir
)
label = "{}:{} MNI: {}".format(
hemi_str, str(vertex_id).ljust(6),
', '.join('%5.1f' % m for m in mni))
try:
mni = vertex_to_mni(
vertices=vertex_id,
hemis=0 if hemi == 'lh' else 1,
subject=self._subject_id,
subjects_dir=self._subjects_dir
)
except Exception:
mni = None
if mni is not None:
mni = ' MNI: ' + ', '.join('%5.1f' % m for m in mni)
else:
mni = ''
label = "{}:{}{}".format(hemi_str, str(vertex_id).ljust(6), mni)
act_data, smooth = self.act_data_smooth[hemi]
if smooth is not None:
act_data = smooth[vertex_id].dot(act_data)[0]
Expand Down Expand Up @@ -2594,32 +2627,23 @@ def screenshot(self, mode='rgb', time_viewer=False):
not self.separate_canvas:
canvas = self.mpl_canvas.fig.canvas
canvas.draw_idle()
# In theory, one of these should work:
#
# trace_img = np.frombuffer(
# canvas.tostring_rgb(), dtype=np.uint8)
# trace_img.shape = canvas.get_width_height()[::-1] + (3,)
#
# or
#
# trace_img = np.frombuffer(
# canvas.tostring_rgb(), dtype=np.uint8)
# size = time_viewer.mpl_canvas.getSize()
# trace_img.shape = (size.height(), size.width(), 3)
#
# But in practice, sometimes the sizes does not match the
# renderer tostring_rgb() size. So let's directly use what
# matplotlib does in lib/matplotlib/backends/backend_agg.py
# before calling tobytes():
trace_img = np.asarray(
canvas.renderer._renderer).take([0, 1, 2], axis=2)
# need to slice into trace_img because generally it's a bit
# smaller
delta = trace_img.shape[1] - img.shape[1]
if delta > 0:
start = delta // 2
trace_img = trace_img[:, start:start + img.shape[1]]
img = np.concatenate([img, trace_img], axis=0)
fig = self.mpl_canvas.fig
with BytesIO() as output:
# Need to pass dpi here so it uses the physical (HiDPI) DPI
# rather than logical DPI when saving in most cases.
# But when matplotlib uses HiDPI and VTK doesn't
# (e.g., macOS w/Qt 5.14+ and VTK9) then things won't work,
# so let's just calculate the DPI we need to get
# the correct size output based on the widths being equal
dpi = img.shape[1] / fig.get_size_inches()[0]
fig.savefig(output, dpi=dpi, format='raw',
facecolor=self._bg_color, edgecolor='none')
output.seek(0)
trace_img = np.reshape(
np.frombuffer(output.getvalue(), dtype=np.uint8),
newshape=(-1, img.shape[1], 4))[:, :, :3]
img = concatenate_images(
[img, trace_img], bgcolor=self._brain_color[:3])
return img

@fill_doc
Expand Down
44 changes: 39 additions & 5 deletions mne/viz/_brain/tests/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
"metadata": {},
"outputs": [],
"source": [
"from contextlib import contextmanager\n",
"import os\n",
"import mne\n",
"from numpy.testing import assert_allclose\n",
"from ipywidgets import Button\n",
"import matplotlib.pyplot as plt\n",
"import mne\n",
"from mne.datasets import testing\n",
"data_path = testing.data_path()\n",
"sample_dir = os.path.join(data_path, 'MEG', 'sample')\n",
Expand All @@ -39,16 +42,40 @@
"initial_time = 0.13\n",
"mne.viz.set_3d_backend('notebook')\n",
"brain_class = mne.viz.get_brain_class()\n",
"for interactive_state in (False, True):\n",
" plt.interactive(interactive_state)\n",
"\n",
"\n",
"@contextmanager\n",
"def interactive(on):\n",
" old = plt.isinteractive()\n",
" plt.interactive(on)\n",
" try:\n",
" yield\n",
" finally:\n",
" plt.interactive(old)\n",
"\n",
"with interactive(False):\n",
" brain = stc.plot(subjects_dir=subjects_dir, initial_time=initial_time,\n",
" clim=dict(kind='value', pos_lims=[3, 6, 9]),\n",
" time_viewer=True,\n",
" hemi='split')\n",
" show_traces=True,\n",
" hemi='lh', size=300)\n",
" assert isinstance(brain, brain_class)\n",
" assert brain.notebook\n",
" assert brain._renderer.figure.display is not None\n",
" brain._update()\n",
" total_number_of_buttons = len([k for k in brain.actions.keys() if '_field' not in k])\n",
" number_of_buttons = 0\n",
" for action in brain.actions.values():\n",
" if isinstance(action, Button):\n",
" action.click()\n",
" number_of_buttons += 1\n",
" assert number_of_buttons == total_number_of_buttons\n",
" img_nv = brain.screenshot()\n",
" assert img_nv.shape == (300, 300, 3), img_nv.shape\n",
" img_v = brain.screenshot(time_viewer=True)\n",
" assert img_v.shape[1:] == (300, 3), img_v.shape\n",
" # XXX This rtol is not very good, ideally would be zero\n",
" assert_allclose(img_v.shape[0], img_nv.shape[0] * 1.25, err_msg=img_nv.shape, rtol=0.1)\n",
" brain.close()"
]
},
Expand All @@ -66,6 +93,13 @@
"mne.viz.set_3d_view(fig, 200, 70, focalpoint=[0, 0, 0])\n",
"assert fig.display is None\n",
"rend.show()\n",
"total_number_of_buttons = len([k for k in rend.actions.keys() if '_field' not in k])\n",
"number_of_buttons = 0\n",
"for action in rend.actions.values():\n",
" if isinstance(action, Button):\n",
" action.click()\n",
" number_of_buttons += 1\n",
"assert number_of_buttons == total_number_of_buttons\n",
"assert fig.display is not None"
]
}
Expand All @@ -86,4 +120,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
Loading

0 comments on commit 72c8f61

Please sign in to comment.