diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 5e74c10e912..2304861d871 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -239,7 +239,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, path = _get_path(path, key, name) # To update the testing or misc dataset, push commits, then make a new # release on GitHub. Then update the "releases" variable: - releases = dict(testing='0.97', misc='0.6') + releases = dict(testing='0.98', misc='0.6') # And also update the "md5_hashes['testing']" variable below. # To update any other dataset, update the data archive itself (upload @@ -326,7 +326,7 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, sample='12b75d1cb7df9dfb4ad73ed82f61094f', somato='ea825966c0a1e9b2f84e3826c5500161', spm='9f43f67150e3b694b523a21eb929ea75', - testing='603c3f087c4dbf151c729341342095c7', + testing='7c1dcfacaac7759aa40bfb800c791d85', multimodal='26ec847ae9ab80f58f204d09e2c08367', fnirs_motor='c4935d19ddab35422a69f3326a01fef8', opm='370ad1dcfd5c47e029e692c85358a374', diff --git a/mne/source_estimate.py b/mne/source_estimate.py index 9cae3eab622..11ffbc0ad3e 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -615,12 +615,12 @@ def plot(self, subject=None, surface='inflated', hemi='lh', colormap='auto', time_label='auto', smoothing_steps=10, transparent=True, alpha=1.0, time_viewer='auto', subjects_dir=None, - figure=None, views='lat', colorbar=True, clim='auto', + figure=None, views='auto', colorbar=True, clim='auto', cortex="classic", size=800, background="black", foreground=None, initial_time=None, time_unit='s', backend='auto', spacing='oct6', title=None, show_traces='auto', src=None, volume_options=1., view_layout='vertical', - verbose=None): + add_data_kwargs=None, verbose=None): brain = plot_source_estimates( self, subject, surface=surface, hemi=hemi, colormap=colormap, time_label=time_label, smoothing_steps=smoothing_steps, @@ -631,7 +631,7 @@ def plot(self, subject=None, surface='inflated', hemi='lh', initial_time=initial_time, time_unit=time_unit, backend=backend, spacing=spacing, title=title, show_traces=show_traces, src=src, volume_options=volume_options, view_layout=view_layout, - verbose=verbose) + add_data_kwargs=add_data_kwargs, verbose=verbose) return brain @property @@ -1904,11 +1904,13 @@ def project(self, directions, src=None, use_cps=True): def plot(self, subject=None, hemi='lh', colormap='hot', time_label='auto', smoothing_steps=10, transparent=True, brain_alpha=0.4, overlay_alpha=None, vector_alpha=1.0, scale_factor=None, - time_viewer='auto', subjects_dir=None, figure=None, views='lat', + time_viewer='auto', subjects_dir=None, figure=None, + views='lateral', colorbar=True, clim='auto', cortex='classic', size=800, background='black', foreground=None, initial_time=None, time_unit='s', show_traces='auto', src=None, volume_options=1., - view_layout='vertical', verbose=None): # noqa: D102 + view_layout='vertical', add_data_kwargs=None, + verbose=None): # noqa: D102 return plot_vector_source_estimates( self, subject=subject, hemi=hemi, colormap=colormap, time_label=time_label, smoothing_steps=smoothing_steps, @@ -1920,7 +1922,8 @@ def plot(self, subject=None, hemi='lh', colormap='hot', time_label='auto', background=background, foreground=foreground, initial_time=initial_time, time_unit=time_unit, show_traces=show_traces, src=src, volume_options=volume_options, - view_layout=view_layout, verbose=verbose) + view_layout=view_layout, add_data_kwargs=add_data_kwargs, + verbose=verbose) class _BaseVolSourceEstimate(_BaseSourceEstimate): @@ -1938,7 +1941,7 @@ def plot_3d(self, subject=None, surface='white', hemi='both', foreground=None, initial_time=None, time_unit='s', backend='auto', spacing='oct6', title=None, show_traces='auto', src=None, volume_options=1., view_layout='vertical', - verbose=None): + add_data_kwargs=None, verbose=None): return super().plot( subject=subject, surface=surface, hemi=hemi, colormap=colormap, time_label=time_label, smoothing_steps=smoothing_steps, @@ -1949,7 +1952,8 @@ def plot_3d(self, subject=None, surface='white', hemi='both', foreground=foreground, initial_time=initial_time, time_unit=time_unit, backend=backend, spacing=spacing, title=title, show_traces=show_traces, src=src, volume_options=volume_options, - view_layout=view_layout, verbose=verbose) + view_layout=view_layout, add_data_kwargs=add_data_kwargs, + verbose=verbose) @copy_function_doc_to_method_doc(plot_volume_source_estimates) def plot(self, src, subject=None, subjects_dir=None, mode='stat_map', @@ -2267,7 +2271,7 @@ def plot_3d(self, subject=None, hemi='both', colormap='hot', background='black', foreground=None, initial_time=None, time_unit='s', show_traces='auto', src=None, volume_options=1., view_layout='vertical', - verbose=None): # noqa: D102 + add_data_kwargs=None, verbose=None): # noqa: D102 return _BaseVectorSourceEstimate.plot( self, subject=subject, hemi=hemi, colormap=colormap, time_label=time_label, smoothing_steps=smoothing_steps, @@ -2279,7 +2283,8 @@ def plot_3d(self, subject=None, hemi='both', colormap='hot', background=background, foreground=foreground, initial_time=initial_time, time_unit=time_unit, show_traces=show_traces, src=src, volume_options=volume_options, - view_layout=view_layout, verbose=verbose) + view_layout=view_layout, add_data_kwargs=add_data_kwargs, + verbose=verbose) @fill_doc diff --git a/mne/surface.py b/mne/surface.py index a7a72b5c166..06d7d0d0327 100644 --- a/mne/surface.py +++ b/mne/surface.py @@ -764,6 +764,53 @@ def _read_wavefront_obj(fname): return np.array(coords), np.array(faces) +def _read_patch(fname): + """Load a FreeSurfer binary patch file. + + Parameters + ---------- + fname : str + The filename. + + Returns + ------- + rrs : ndarray, shape (n_vertices, 3) + The points. + tris : ndarray, shape (n_tris, 3) + The patches. Not all vertices will be present. + """ + # This is adapted from PySurfer PR #269, Bruce Fischl's read_patch.m, + # and PyCortex (BSD) + patch = dict() + with open(fname, 'r') as fid: + ver = np.fromfile(fid, dtype='>i4', count=1)[0] + if ver != -1: + raise RuntimeError(f'incorrect version # {ver} (not -1) found') + npts = np.fromfile(fid, dtype='>i4', count=1)[0] + dtype = np.dtype( + [('vertno', '>i4'), ('x', '>f'), ('y', '>f'), ('z', '>f')]) + recs = np.fromfile(fid, dtype=dtype, count=npts) + # numpy to dict + patch = {key: recs[key] for key in dtype.fields.keys()} + patch['vertno'] -= 1 + + # read surrogate surface + rrs, tris = read_surface( + op.join(op.dirname(fname), op.basename(fname)[:3] + 'sphere')) + orig_tris = tris + is_vert = patch['vertno'] > 0 # negative are edges, ignored for now + verts = patch['vertno'][is_vert] + + # eliminate invalid tris and zero out unused rrs + mask = np.zeros((len(rrs),), dtype=bool) + mask[verts] = True + rrs[~mask] = 0. + tris = tris[mask[tris].all(1)] + for ii, key in enumerate(['x', 'y', 'z']): + rrs[verts, ii] = patch[key][is_vert] + return rrs, tris, orig_tris + + ############################################################################## # SURFACE CREATION diff --git a/mne/tests/test_surface.py b/mne/tests/test_surface.py index fa25c98b033..0ada7494682 100644 --- a/mne/tests/test_surface.py +++ b/mne/tests/test_surface.py @@ -13,7 +13,7 @@ dig_mri_distances) from mne.surface import (read_morph_map, _compute_nearest, _tessellate_sphere, fast_cross_3d, get_head_surf, read_curvature, - get_meg_helmet_surf, _normal_orth) + get_meg_helmet_surf, _normal_orth, _read_patch) from mne.utils import (_TempDir, requires_vtk, catch_logging, run_tests_if_main, object_diff, requires_freesurfer) from mne.io import read_info @@ -183,6 +183,12 @@ def test_io_surface(): assert_array_equal(pts, c_pts) assert_array_equal(tri, c_tri) + # reading patches (just a smoke test, let the flatmap viz tests be more + # complete) + fname_patch = op.join( + data_path, 'subjects', 'fsaverage', 'surf', 'rh.cortex.patch.flat') + _read_patch(fname_patch) + @testing.requires_testing_data def test_read_curv(): diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 3eacb38b566..265a575aa8b 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1058,6 +1058,21 @@ Can be "vertical" (default) or "horizontal". When using "horizontal" mode, the PyVista backend must be used and hemi cannot be "split". """ +docdict['add_data_kwargs'] = """ +add_data_kwargs : dict | None + Additional arguments to brain.add_data (e.g., + ``dict(time_label_size=10)``). +""" +docdict['views'] = """ +views : str | list + View to use. Can be any of:: + + ['lateral', 'medial', 'rostral', 'caudal', 'dorsal', 'ventral', + 'frontal', 'parietal', 'axial', 'sagittal', 'coronal'] + + Three letter abbreviations (e.g., ``'lat'``) are also supported. + Using multiple views (list) is not supported for mpl backend. +""" # STC label time course docdict['eltc_labels'] = """ diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 4c89132019d..a7890f5d163 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -1426,6 +1426,7 @@ def _plot_mpl_stc(stc, subject=None, surface='inflated', hemi='lh', 'par': {'elev': 30, 'azim': -60}} time_viewer = False if time_viewer == 'auto' else time_viewer kwargs = dict(lh=lh_kwargs, rh=rh_kwargs) + views = 'lat' if views == 'auto' else views _check_option('views', views, sorted(lh_kwargs.keys())) mapdata = _process_clim(clim, colormap, transparent, stc.data) _separate_map(mapdata) @@ -1582,13 +1583,13 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', colormap='auto', time_label='auto', smoothing_steps=10, transparent=True, alpha=1.0, time_viewer='auto', subjects_dir=None, figure=None, - views='lat', colorbar=True, clim='auto', + views='auto', colorbar=True, clim='auto', cortex="classic", size=800, background="black", foreground=None, initial_time=None, time_unit='s', backend='auto', spacing='oct6', title=None, show_traces='auto', src=None, volume_options=1., view_layout='vertical', - verbose=None): + add_data_kwargs=None, verbose=None): """Plot SourceEstimate. Parameters @@ -1629,10 +1630,14 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', length. If int is provided it will be used to identify the Mayavi figure by it's id or create a new figure with the given id. If an instance of matplotlib figure, mpl backend is used for plotting. - views : str | list - View to use. See `surfer.Brain`. Supported views: ['lat', 'med', 'ros', - 'cau', 'dor' 'ven', 'fro', 'par']. Using multiple views is not - supported for mpl backend. + %(views)s + + When plotting a standard SourceEstimate (not volume, mixed, or vector) + and using the PyVista backend, ``views='flat'`` is also supported to + plot cortex as a flatmap. + + .. versionchanged:: 0.21.0 + Support for flatmaps. colorbar : bool If True, display colorbar on scene. %(clim)s @@ -1676,6 +1681,7 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', .. versionadded:: 0.17.0 %(show_traces)s %(src_volume_options_layout)s + %(add_data_kwargs)s %(verbose)s Returns @@ -1683,6 +1689,17 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', figure : instance of surfer.Brain | matplotlib.figure.Figure An instance of :class:`surfer.Brain` from PySurfer or matplotlib figure. + + Notes + ----- + Flatmaps are available by default for ``fsaverage`` but not for other + subjects reconstructed by FreeSurfer. We recommend using + :func:`mne.compute_source_morph` to morph source estimates to ``fsaverage`` + for flatmap plotting. If you want to construct your own flatmap for a given + subject, these links might help: + + - https://surfer.nmr.mgh.harvard.edu/fswiki/FreeSurferOccipitalFlattenedPatch + - https://openwetware.org/wiki/Beauchamp:FreeSurfer """ # noqa: E501 from .backends.renderer import _get_3d_backend, set_3d_backend from ..source_estimate import _BaseSourceEstimate @@ -1716,7 +1733,7 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh', stc, overlay_alpha=alpha, brain_alpha=alpha, vector_alpha=alpha, cortex=cortex, foreground=foreground, size=size, scale_factor=None, show_traces=show_traces, src=src, volume_options=volume_options, - view_layout=view_layout, **kwargs) + view_layout=view_layout, add_data_kwargs=add_data_kwargs, **kwargs) def _plot_stc(stc, subject, surface, hemi, colormap, time_label, @@ -1724,7 +1741,7 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, time_unit, background, time_viewer, colorbar, transparent, brain_alpha, overlay_alpha, vector_alpha, cortex, foreground, size, scale_factor, show_traces, src, volume_options, - view_layout): + view_layout, add_data_kwargs): from .backends.renderer import _get_3d_backend vec = stc._data_ndim == 3 subjects_dir = get_subjects_dir(subjects_dir=subjects_dir, @@ -1739,7 +1756,7 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, _require_version('surfer', 'stc.plot', '0.9') else: # PyVista from ._brain import _Brain as Brain - + views = _check_views(surface, views, hemi, stc, backend) _check_option('hemi', hemi, ['lh', 'rh', 'split', 'both']) _check_option('view_layout', view_layout, ('vertical', 'horizontal')) time_label, times = _handle_time(time_label, time_unit, stc.times) @@ -1842,6 +1859,7 @@ def _plot_stc(stc, subject, surface, hemi, colormap, time_label, kwargs["clim"] = clim kwargs["volume_options"] = volume_options kwargs["src"] = src_vol + kwargs.update({} if add_data_kwargs is None else add_data_kwargs) with warnings.catch_warnings(record=True): # traits warnings brain.add_data(**kwargs) brain.scale_data_colormap(fmin=scale_pts[0], fmid=scale_pts[1], @@ -2346,19 +2364,44 @@ def _check_pysurfer_antialias(Brain): return kwargs +def _check_views(surf, views, hemi, stc=None, backend=None): + from ..source_estimate import SourceEstimate + _validate_type(views, (list, tuple, str), 'views') + views = [views] if isinstance(views, str) else list(views) + if surf == 'flat': + _check_option('views', views, (['auto'], ['flat'])) + views = ['flat'] + elif len(views) == 1 and views[0] == 'auto': + views = ['lateral'] + if views == ['flat']: + if stc is not None: + _validate_type(stc, SourceEstimate, 'stc', + 'SourceEstimate when a flatmap is used') + if backend is not None: + if backend != 'pyvista': + raise RuntimeError('The PyVista 3D backend must be used to ' + 'plot a flatmap') + if (views == ['flat']) ^ (surf == 'flat'): # exactly only one of the two + raise ValueError('surface="flat" must be used with views="flat", got ' + f'surface={repr(surf)} and views={repr(views)}') + return views + + @verbose def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', time_label='auto', smoothing_steps=10, transparent=None, brain_alpha=0.4, overlay_alpha=None, vector_alpha=1.0, scale_factor=None, time_viewer='auto', - subjects_dir=None, figure=None, views='lat', + subjects_dir=None, figure=None, + views='lateral', colorbar=True, clim='auto', cortex='classic', size=800, background='black', foreground=None, initial_time=None, time_unit='s', show_traces='auto', src=None, volume_options=1., - view_layout='vertical', verbose=None): + view_layout='vertical', + add_data_kwargs=None, verbose=None): """Plot VectorSourceEstimate with PySurfer. A "glass brain" is drawn and all dipoles defined in the source estimate @@ -2406,8 +2449,7 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', split view is requested, this must be a list of the appropriate length. If int is provided it will be used to identify the Mayavi figure by it's id or create a new figure with the given id. - views : str | list - View to use. See `surfer.Brain`. + %(views)s colorbar : bool If True, display colorbar on scene. %(clim_onesided)s @@ -2433,6 +2475,7 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', milliseconds ("ms"). %(show_traces)s %(src_volume_options_layout)s + %(add_data_kwargs)s %(verbose)s Returns @@ -2459,7 +2502,8 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot', brain_alpha=brain_alpha, overlay_alpha=overlay_alpha, vector_alpha=vector_alpha, cortex=cortex, foreground=foreground, size=size, scale_factor=scale_factor, show_traces=show_traces, - src=src, volume_options=volume_options, view_layout=view_layout) + src=src, volume_options=volume_options, view_layout=view_layout, + add_data_kwargs=add_data_kwargs) @verbose diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index c0fd1fd061d..ac2c4bb8341 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -17,7 +17,7 @@ from .surface import Surface from .view import views_dicts -from .._3d import _process_clim, _handle_time +from .._3d import _process_clim, _handle_time, _check_views from ...defaults import _handle_default from ...surface import mesh_edges @@ -154,13 +154,15 @@ class _Brain(object): +---------------------------+--------------+-----------------------+ | view_layout | | ✓ | +---------------------------+--------------+-----------------------+ + | flatmaps | | ✓ | + +---------------------------+--------------+-----------------------+ """ def __init__(self, subject_id, hemi, surf, title=None, cortex="classic", alpha=1.0, size=800, background="black", foreground=None, figure=None, subjects_dir=None, - views=['lateral'], offset=True, show_toolbar=False, + views='auto', offset=True, show_toolbar=False, offscreen=False, interaction='trackball', units='mm', view_layout='vertical', show=True): from ..backends.renderer import backend, _get_renderer, _get_3d_backend @@ -196,6 +198,7 @@ def __init__(self, subject_id, hemi, surf, title=None, if isinstance(views, str): views = [views] + views = _check_views(surf, views, hemi) col_dict = dict(lh=1, rh=1, both=1, split=2) shape = (len(views), col_dict[hemi]) if self._view_layout == 'horizontal': @@ -291,6 +294,10 @@ def __init__(self, subject_id, hemi, surf, title=None, self._closed = False if show: self._renderer.show() + # update the views once the geometry is all set + for h in self._hemis: + for ri, ci, v in self._iter_views(h): + self.show_view(v, row=ri, col=ci, hemi=h) @property def interaction(self): @@ -326,7 +333,8 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, time_label="auto", colorbar=True, hemi=None, remove_existing=None, time_label_size=None, initial_time=None, scale_factor=None, vector_alpha=None, - clim=None, src=None, volume_options=0.4, verbose=None): + clim=None, src=None, volume_options=0.4, colorbar_kwargs=None, + verbose=None): """Display data from a numpy array on the surface or volume. This provides a similar interface to @@ -384,7 +392,8 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, time points in the data array (if data is 2D or 3D) %(time_label)s colorbar : bool - whether to add a colorbar to the figure + whether to add a colorbar to the figure. Can also be a tuple + to give the (row, col) index of where to put the colorbar. hemi : str | None If None, it is assumed to belong to the hemisphere being shown. If two hemispheres are being shown, an error will @@ -409,6 +418,9 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, clim : dict Original clim arguments. %(src_volume_options_layout)s + colorbar_kwargs : dict | None + Options to pass to :meth:`pyvista.BasePlotter.add_scalar_bar` + (e.g., ``dict(title_font_size=10)``). %(verbose)s Notes @@ -430,7 +442,12 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, # those parameters are not supported yet, only None is allowed _check_option('thresh', thresh, [None]) _check_option('remove_existing', remove_existing, [None]) - _check_option('time_label_size', time_label_size, [None]) + _validate_type(time_label_size, (None, 'numeric'), 'time_label_size') + if time_label_size is not None: + time_label_size = float(time_label_size) + if time_label_size < 0: + raise ValueError('time_label_size must be positive, got ' + f'{time_label_size}') hemi = self._check_hemi(hemi, extras=['vol']) array = np.asarray(array) @@ -479,6 +496,8 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, fmin, fmid, fmax = _update_limits( fmin, fmid, fmax, center, array ) + if colormap == 'auto': + colormap = 'mne' if center is not None else 'hot' if smoothing_steps is None: smoothing_steps = 7 @@ -538,10 +557,13 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, self.set_data_smoothing(self._data['smoothing_steps']) # 3) add the other actors + if colorbar is True: + # botto left by default + colorbar = (self._subplot_shape[0] - 1, 0) for ri, ci, v in self._iter_views(hemi): self._renderer.subplot(ri, ci) # Add the time label to the bottommost view - do = (ri == self._subplot_shape[0] - 1) + do = (ri, ci) == colorbar if not self._time_label_added and time_label is not None and do: time_actor = self._renderer.text2d( x_window=0.95, y_window=y_txt, @@ -553,9 +575,10 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None, self._data['time_actor'] = time_actor self._time_label_added = True if colorbar and not self._colorbar_added and do: - self._renderer.scalarbar(source=actor, n_labels=8, - color=self._fg_color, - bgcolor=self._brain_color[:3]) + kwargs = dict(source=actor, n_labels=8, color=self._fg_color, + bgcolor=self._brain_color[:3]) + kwargs.update(colorbar_kwargs or {}) + self._renderer.scalarbar(**kwargs) self._colorbar_added = True self._renderer.set_camera(**views_dicts[hemi][v]) self._update() @@ -1219,10 +1242,9 @@ def update_lut(self, fmin=None, fmid=None, fmax=None): # update our values rng = self._cmap_range ctable = self._data['ctable'] - if self._colorbar_added: - scalar_bar = self._renderer.plotter.scalar_bar - else: - scalar_bar = None + # in testing, no plotter; if colorbar=False, no scalar_bar + scalar_bar = getattr( + getattr(self._renderer, 'plotter', None), 'scalar_bar', None) for hemi in ['lh', 'rh', 'vol']: hemi_data = self._data.get(hemi) if hemi_data is not None: @@ -1271,7 +1293,7 @@ def set_data_smoothing(self, n_steps): maps = sparse.eye(len(self.geo[hemi].coords), format='csr') with use_log_level(False): smooth_mat = _hemi_morph( - self.geo[hemi].faces, + self.geo[hemi].orig_faces, np.arange(len(self.geo[hemi].coords)), vertices, morph_n_steps, maps, warn=False) self._data[hemi]['smooth_mat'] = smooth_mat @@ -1645,16 +1667,18 @@ def _to_borders(self, label, hemi, borders, restrict_idx=None): raise ValueError('borders must be a bool or positive integer') if borders: n_vertices = label.size - edges = mesh_edges(self.geo[hemi].faces) + edges = mesh_edges(self.geo[hemi].orig_faces) edges = edges.tocoo() border_edges = label[edges.row] != label[edges.col] show = np.zeros(n_vertices, dtype=np.int64) keep_idx = np.unique(edges.row[border_edges]) if isinstance(borders, int): for _ in range(borders): - keep_idx = np.in1d(self.geo[hemi].faces.ravel(), keep_idx) - keep_idx.shape = self.geo[hemi].faces.shape - keep_idx = self.geo[hemi].faces[np.any(keep_idx, axis=1)] + keep_idx = np.in1d( + self.geo[hemi].orig_faces.ravel(), keep_idx) + keep_idx.shape = self.geo[hemi].orig_faces.shape + keep_idx = self.geo[hemi].orig_faces[ + np.any(keep_idx, axis=1)] keep_idx = np.unique(keep_idx) if restrict_idx is not None: keep_idx = keep_idx[np.in1d(keep_idx, restrict_idx)] diff --git a/mne/viz/_brain/_timeviewer.py b/mne/viz/_brain/_timeviewer.py index eb0e690e4b1..32b5ab049a6 100644 --- a/mne/viz/_brain/_timeviewer.py +++ b/mne/viz/_brain/_timeviewer.py @@ -429,13 +429,16 @@ def ensure_minimum_sizes(self): yield finally: self.splitter.setSizes([sz[1], mpl_h]) + _process_events(self.plotter) + _process_events(self.plotter) self.mpl_canvas.canvas.setMinimumSize(0, 0) _process_events(self.plotter) + _process_events(self.plotter) + # sizes could change, update views for hemi in ('lh', 'rh'): - if hemi == 'rh' and self.brain._hemi == 'split': - continue for ri, ci, v in self.brain._iter_views(hemi): self.brain.show_view(view=v, row=ri, col=ci) + _process_events(self.plotter) def toggle_interface(self, value=None): if value is None: @@ -660,6 +663,9 @@ def configure_sliders(self): for hemi in hemis_ref: for ri, ci, view in self.brain._iter_views(hemi): self.plotter.subplot(ri, ci) + if view == 'flat': + self.orientation_call = None + continue self.orientation_call = ShowView( plotter=self.plotter, brain=self.brain, @@ -1297,9 +1303,10 @@ def clean(self): self.reps = None self._time_slider = None self._playback_speed_slider = None - self.orientation_call.plotter = None - self.orientation_call.brain = None - self.orientation_call = None + if self.orientation_call is not None: + self.orientation_call.plotter = None + self.orientation_call.brain = None + self.orientation_call = None self.smoothing_call.plotter = None self.smoothing_call = None if self.time_call is not None: diff --git a/mne/viz/_brain/colormap.py b/mne/viz/_brain/colormap.py index 23f1fbdb5ab..d5de54c7e7c 100644 --- a/mne/viz/_brain/colormap.py +++ b/mne/viz/_brain/colormap.py @@ -11,6 +11,7 @@ def create_lut(cmap, n_colors=256, center=None): """Return a colormap suitable for setting as a LUT.""" from matplotlib import cm + assert not (isinstance(cmap, str) and cmap == 'auto') cmap = cm.get_cmap(cmap) lut = np.round(cmap(np.linspace(0, 1, n_colors)) * 255.0).astype(np.int64) return lut diff --git a/mne/viz/_brain/surface.py b/mne/viz/_brain/surface.py index de5b9cfde03..9bf533f2ac2 100644 --- a/mne/viz/_brain/surface.py +++ b/mne/viz/_brain/surface.py @@ -9,8 +9,9 @@ from os import path as path import numpy as np -from ...utils import _check_option, get_subjects_dir -from ...surface import complete_surface_info, read_surface, read_curvature +from ...utils import _check_option, get_subjects_dir, _check_fname +from ...surface import (complete_surface_info, read_surface, read_curvature, + _read_patch) class Surface(object): @@ -108,9 +109,17 @@ def load_geometry(self): ------- None """ - surf_path = path.join(self.data_path, 'surf', - '%s.%s' % (self.hemi, self.surf)) - coords, faces = read_surface(surf_path) + if self.surf == 'flat': # special case + fname = path.join(self.data_path, 'surf', + '%s.%s' % (self.hemi, 'cortex.patch.flat')) + _check_fname(fname, overwrite='read', must_exist=True, + name='flatmap surface file') + coords, faces, orig_faces = _read_patch(fname) + else: + coords, faces = read_surface( + path.join(self.data_path, 'surf', + '%s.%s' % (self.hemi, self.surf))) + orig_faces = faces if self.units == 'm': coords /= 1000. if self.offset is not None: @@ -121,15 +130,10 @@ def load_geometry(self): surf = dict(rr=coords, tris=faces) complete_surface_info(surf, copy=False, verbose=False) nn = surf['nn'] - - if self.coords is None: - self.coords = coords - self.faces = faces - self.nn = nn - else: - self.coords[:] = coords - self.faces[:] = faces - self.nn[:] = nn + self.coords = coords + self.faces = faces + self.orig_faces = orig_faces + self.nn = nn def __len__(self): """Return number of vertices.""" diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 2ecbc1dfd1f..4e93000c765 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -138,14 +138,17 @@ def test_brain_init(renderer, tmpdir, pixel_ratio): with pytest.raises(ValueError, match='remove_existing'): brain.add_data(hemi_data, hemi=h, remove_existing=-1) with pytest.raises(ValueError, match='time_label_size'): - brain.add_data(hemi_data, hemi=h, time_label_size=-1) + brain.add_data(hemi_data, hemi=h, time_label_size=-1, + vertices=hemi_vertices) with pytest.raises(ValueError, match='is positive'): - brain.add_data(hemi_data, hemi=h, smoothing_steps=-1) + brain.add_data(hemi_data, hemi=h, smoothing_steps=-1, + vertices=hemi_vertices) with pytest.raises(TypeError, match='int or NoneType'): brain.add_data(hemi_data, hemi=h, smoothing_steps='foo') - with pytest.raises(ValueError): - brain.add_data(array=np.array([0, 1, 2]), hemi=h) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='dimension mismatch'): + brain.add_data(array=np.array([0, 1, 2]), hemi=h, + vertices=hemi_vertices) + with pytest.raises(ValueError, match='vertices parameter must not be'): brain.add_data(hemi_data, fmin=fmin, hemi=hemi, fmax=fmax, vertices=None) with pytest.raises(ValueError, match='has shape'): @@ -154,7 +157,7 @@ def test_brain_init(renderer, tmpdir, pixel_ratio): brain.add_data(hemi_data, fmin=fmin, hemi=h, fmax=fmax, colormap='hot', vertices=hemi_vertices, - smoothing_steps='nearest', colorbar=False, time=None) + smoothing_steps='nearest', colorbar=(0, 0), time=None) assert brain.data['lh']['array'] is hemi_data assert brain.views == ['lateral'] assert brain.hemis == ('lh',) @@ -220,8 +223,7 @@ def test_brain_init(renderer, tmpdir, pixel_ratio): if renderer._get_3d_backend() == 'mayavi': pixel_ratio = 1. # no HiDPI when using the testing backend want_size = np.array([size[0] * pixel_ratio, size[1] * pixel_ratio, 3]) - assert_allclose(img.shape, want_size, - atol=70 * pixel_ratio) # XXX undo once size is fixed + assert_allclose(img.shape, want_size) brain.close() diff --git a/mne/viz/_brain/view.py b/mne/viz/_brain/view.py index acba52030a3..5bbd0f73580 100644 --- a/mne/viz/_brain/view.py +++ b/mne/viz/_brain/view.py @@ -6,39 +6,45 @@ # # License: Simplified BSD +ORIGIN = (0., 0., 0.) + _lh_views_dict = { - 'lateral': dict(azimuth=180., elevation=90.), - 'medial': dict(azimuth=0., elevation=90.0), - 'rostral': dict(azimuth=90., elevation=90.), - 'caudal': dict(azimuth=270., elevation=90.), - 'dorsal': dict(azimuth=180., elevation=0.), - 'ventral': dict(azimuth=180., elevation=180.), - 'frontal': dict(azimuth=120., elevation=80.), - 'parietal': dict(azimuth=-120., elevation=60.), - 'sagittal': dict(azimuth=180., elevation=-90.), - 'coronal': dict(azimuth=90., elevation=-90.), - 'axial': dict(azimuth=180., elevation=0., roll=180), + 'lateral': dict(azimuth=180., elevation=90., focalpoint=ORIGIN), + 'medial': dict(azimuth=0., elevation=90.0, focalpoint=ORIGIN), + 'rostral': dict(azimuth=90., elevation=90., focalpoint=ORIGIN), + 'caudal': dict(azimuth=270., elevation=90., focalpoint=ORIGIN), + 'dorsal': dict(azimuth=180., elevation=0., focalpoint=ORIGIN), + 'ventral': dict(azimuth=180., elevation=180., focalpoint=ORIGIN), + 'frontal': dict(azimuth=120., elevation=80., focalpoint=ORIGIN), + 'parietal': dict(azimuth=-120., elevation=60., focalpoint=ORIGIN), + 'sagittal': dict(azimuth=180., elevation=-90., focalpoint=ORIGIN), + 'coronal': dict(azimuth=90., elevation=-90., focalpoint=ORIGIN), + 'axial': dict(azimuth=180., elevation=0., focalpoint=ORIGIN, + roll=180), } _rh_views_dict = { - 'lateral': dict(azimuth=180., elevation=-90.), - 'medial': dict(azimuth=0., elevation=-90.0), - 'rostral': dict(azimuth=-90., elevation=-90.), - 'caudal': dict(azimuth=90., elevation=-90.), - 'dorsal': dict(azimuth=180., elevation=0.), - 'ventral': dict(azimuth=180., elevation=180.), - 'frontal': dict(azimuth=60., elevation=80.), - 'parietal': dict(azimuth=-60., elevation=60.), - 'sagittal': dict(azimuth=180., elevation=-90.), - 'coronal': dict(azimuth=90., elevation=-90.), - 'axial': dict(azimuth=180., elevation=0., roll=180), + 'lateral': dict(azimuth=180., elevation=-90., focalpoint=ORIGIN), + 'medial': dict(azimuth=0., elevation=-90.0, focalpoint=ORIGIN), + 'rostral': dict(azimuth=-90., elevation=-90., focalpoint=ORIGIN), + 'caudal': dict(azimuth=90., elevation=-90., focalpoint=ORIGIN), + 'dorsal': dict(azimuth=180., elevation=0., focalpoint=ORIGIN), + 'ventral': dict(azimuth=180., elevation=180., focalpoint=ORIGIN), + 'frontal': dict(azimuth=60., elevation=80., focalpoint=ORIGIN), + 'parietal': dict(azimuth=-60., elevation=60., focalpoint=ORIGIN), + 'sagittal': dict(azimuth=180., elevation=-90., focalpoint=ORIGIN), + 'coronal': dict(azimuth=90., elevation=-90., focalpoint=ORIGIN), + 'axial': dict(azimuth=180., elevation=0., focalpoint=ORIGIN, + roll=180), } # add short-size version entries into the dict lh_views_dict = _lh_views_dict.copy() for k, v in _lh_views_dict.items(): lh_views_dict[k[:3]] = v + lh_views_dict['flat'] = dict(azimuth=250, elevation=0, focalpoint=ORIGIN) rh_views_dict = _rh_views_dict.copy() for k, v in _rh_views_dict.items(): rh_views_dict[k[:3]] = v + rh_views_dict['flat'] = dict(azimuth=-70, elevation=0, focalpoint=ORIGIN) views_dicts = dict(lh=lh_views_dict, vol=lh_views_dict, both=lh_views_dict, rh=rh_views_dict) diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index c103307dc0f..48e27255bb8 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -198,11 +198,24 @@ def ensure_minimum_sizes(self): # plotter.ren_win: vtkXOpenGLRenderWindow self.plotter.interactor.setMinimumSize(*sz) try: - yield + yield # show finally: - for _ in range(2): - _process_events(self.plotter) + # 1. Process events + _process_events(self.plotter) + _process_events(self.plotter) + # 2. Get the window size that accommodates the size + sz = self.plotter.app_window.size() + # 3. Call app_window.setBaseSize and resize (in pyvistaqt) + self.plotter.window_size = (sz.width(), sz.height()) + # 4. Undo the min size setting and process events self.plotter.interactor.setMinimumSize(0, 0) + _process_events(self.plotter) + _process_events(self.plotter) + # 5. Resize the window (again!) to the correct size + # (not sure why, but this is required on macOS at least) + self.plotter.window_size = (sz.width(), sz.height()) + _process_events(self.plotter) + _process_events(self.plotter) def subplot(self, x, y): x = np.max([0, np.min([x, self.shape[0] - 1])]) @@ -510,25 +523,22 @@ def text3d(self, x, y, z, text, scale, color='white'): shape_opacity=0) def scalarbar(self, source, color="white", title=None, n_labels=4, - bgcolor=None): + bgcolor=None, **extra_kwargs): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) - self.plotter.add_scalar_bar(color=color, title=title, - n_labels=n_labels, - use_opacity=False, n_colors=256, - position_x=0.15, - position_y=0.05, width=0.7, - shadow=False, bold=True, - label_font_size=22, - font_family=self.font_family, - background_color=bgcolor) + kwargs = dict(color=color, title=title, n_labels=n_labels, + use_opacity=False, n_colors=256, position_x=0.15, + position_y=0.05, width=0.7, shadow=False, bold=True, + label_font_size=22, font_family=self.font_family, + background_color=bgcolor) + kwargs.update(extra_kwargs) + self.plotter.add_scalar_bar(**kwargs) def show(self): self.figure.display = self.plotter.show() if hasattr(self.plotter, "app_window"): with self.ensure_minimum_sizes(): self.plotter.app_window.show() - _process_events(self.plotter, show=True) return self.scene() def close(self): @@ -663,7 +673,8 @@ def _get_camera_direction(focalpoint, position): def _set_3d_view(figure, azimuth, elevation, focalpoint, distance, roll=None): position = np.array(figure.plotter.camera_position[0]) - focalpoint = np.array(figure.plotter.camera_position[1]) + if focalpoint is None: + focalpoint = np.array(figure.plotter.camera_position[1]) r, theta, phi, fp = _get_camera_direction(focalpoint, position) if azimuth is not None: @@ -739,13 +750,11 @@ def _take_3d_screenshot(figure, mode='rgb', filename=None): filename=filename) -def _process_events(plotter, show=False): +def _process_events(plotter): if hasattr(plotter, 'app'): with warnings.catch_warnings(record=True): warnings.filterwarnings('ignore', 'constrained_layout') plotter.app.processEvents() - if show: - plotter.app_window.show() def _set_colormap_range(actor, ctable, scalar_bar, rng=None): diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 06ebaa5290a..029b9c18e2c 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -705,9 +705,10 @@ def test_plot_source_estimates(renderer_interactive, all_src_types_inv_evoked, with pytest.warns(None): # PCA mag stc = apply_inverse(evoked, inv, pick_ori=pick_ori) stc.data[1] *= -1 # make it signed - meth = 'plot_3d' if isinstance(stc, _BaseVolSourceEstimate) else 'plot' - meth = getattr(stc, meth) - kwargs = dict(subject='sample', subjects_dir=subjects_dir, + meth_key = 'plot_3d' if isinstance(stc, _BaseVolSourceEstimate) else 'plot' + stc.subject = 'sample' + meth = getattr(stc, meth_key) + kwargs = dict(subjects_dir=subjects_dir, time_viewer=False, show_traces=False, # for speed smoothing_steps=1, verbose='error', src=inv['src'], volume_options=dict(resolution=None), # for speed @@ -740,6 +741,31 @@ def test_plot_source_estimates(renderer_interactive, all_src_types_inv_evoked, with pytest.raises(ValueError, match='view_layout must be'): meth(view_layout='horizontal', **kwargs) + # flatmaps (mostly a lot of error checking) + these_kwargs = kwargs.copy() + these_kwargs.update(surface='flat', views='auto') + if kind == 'surface' and pick_ori != 'vector' and is_pyvista: + with pytest.raises(FileNotFoundError, match='flatmap'): + meth(**these_kwargs) # sample does not have them + fs_stc = stc.copy() + fs_stc.subject = 'fsaverage' # this is wrong, but don't have to care + flat_meth = getattr(fs_stc, meth_key) + these_kwargs.pop('src') + if pick_ori == 'vector': + pass # can't even pass "surface" variable + elif kind != 'surface': + with pytest.raises(TypeError, match='SourceEstimate when a flatmap'): + flat_meth(**these_kwargs) + elif not is_pyvista: + with pytest.raises(RuntimeError, match='PyVista 3D backend.*flatmap'): + flat_meth(**these_kwargs) + else: + brain = flat_meth(**these_kwargs) + brain.close() + these_kwargs.update(surface='inflated', views='flat') + with pytest.raises(ValueError, match='surface="flat".*views="flat"'): + flat_meth(**these_kwargs) + # just test one for speed if kind != 'mixed': return diff --git a/tutorials/source-modeling/plot_visualize_stc.py b/tutorials/source-modeling/plot_visualize_stc.py index 2d530b9f80c..85e8cca8f21 100644 --- a/tutorials/source-modeling/plot_visualize_stc.py +++ b/tutorials/source-modeling/plot_visualize_stc.py @@ -50,7 +50,21 @@ # and ``pysurfer`` installed on your machine. initial_time = 0.1 brain = stc.plot(subjects_dir=subjects_dir, initial_time=initial_time, - clim=dict(kind='value', pos_lims=[3, 6, 9])) + clim=dict(kind='value', lims=[3, 6, 9])) + +############################################################################### +# You can also morph it to fsaverage and visualize it using a flatmap: + +# sphinx_gallery_thumbnail_number = 2 + +stc_fs = mne.compute_source_morph(stc, 'sample', 'fsaverage', subjects_dir, + smooth=5, verbose='error').apply(stc) +brain = stc_fs.plot(subjects_dir=subjects_dir, initial_time=initial_time, + clim=dict(kind='value', lims=[3, 6, 9]), + surface='flat', hemi='split', size=(1000, 500), + smoothing_steps=5, time_viewer=False, + add_data_kwargs=dict( + colorbar_kwargs=dict(label_font_size=10))) ############################################################################### #