Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG, MAINT: Unify 90% similar code #8068

Merged
merged 1 commit into from
Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mne/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
_check_rank, _check_option, _check_depth, _check_combine,
_check_path_like, _check_src_normal, _check_stc_units,
_check_pyqt5_version, _check_sphere, _check_time_format,
_check_freesurfer_home, _suggest)
_check_freesurfer_home, _suggest, _require_version)
from .config import (set_config, get_config, get_config_path, set_cache_dir,
set_memmap_min_size, get_subjects_dir, _get_stim_channel,
sys_info, _get_extra_data_path, _get_root_dir,
Expand Down
305 changes: 131 additions & 174 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#
# License: Simplified BSD

from distutils.version import LooseVersion
from itertools import cycle
import os.path as op
import sys
Expand Down Expand Up @@ -39,7 +38,8 @@
read_ras_mni_t, _print_coord_trans)
from ..utils import (get_subjects_dir, logger, _check_subject, verbose, warn,
has_nibabel, check_version, fill_doc, _pl, get_config,
_ensure_int, _validate_type, _check_option)
_ensure_int, _validate_type, _check_option,
_require_version)
from .utils import (mne_analyze_colormap, _get_color_list,
plt_show, tight_layout, figure_nobar, _check_time_unit)
from .misc import _check_mri
Expand Down Expand Up @@ -1675,32 +1675,62 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
plot_mpl = True
else: # 'mayavi'
raise

kwargs = dict(
subject=subject, surface=surface, hemi=hemi, colormap=colormap,
time_label=time_label, smoothing_steps=smoothing_steps,
subjects_dir=subjects_dir, views=views, clim=clim,
figure=figure, initial_time=initial_time, time_unit=time_unit,
background=background, time_viewer=time_viewer, colorbar=colorbar,
transparent=transparent)
if plot_mpl:
return _plot_mpl_stc(stc, subject=subject, surface=surface, hemi=hemi,
colormap=colormap, time_label=time_label,
smoothing_steps=smoothing_steps,
subjects_dir=subjects_dir, views=views, clim=clim,
figure=figure, initial_time=initial_time,
time_unit=time_unit, background=background,
spacing=spacing, time_viewer=time_viewer,
colorbar=colorbar, transparent=transparent)

if _get_3d_backend() == "mayavi":
return _plot_mpl_stc(stc, spacing=spacing, **kwargs)
return _plot_stc(
stc, overlay_alpha=alpha, brain_alpha=alpha, vector_alpha=alpha,
cortex=cortex, foreground=foreground, size=size, scale_factor=None,
show_traces=show_traces, **kwargs)


def _plot_stc(stc, subject, surface, hemi, colormap, time_label,
smoothing_steps, subjects_dir, views, clim, figure, initial_time,
time_unit, background, time_viewer, colorbar, transparent,
brain_alpha, overlay_alpha, vector_alpha, cortex, foreground,
size, scale_factor, show_traces):
from .backends.renderer import _get_3d_backend
from ..source_estimate import (
_BaseSourceEstimate, SourceEstimate, VectorSourceEstimate)
_validate_type(stc, _BaseSourceEstimate)
vec = stc._data_ndim == 3
if vec:
allowed = VectorSourceEstimate
else:
allowed = SourceEstimate
_validate_type(stc, allowed, 'stc')
subjects_dir = get_subjects_dir(subjects_dir=subjects_dir,
raise_error=True)
subject = _check_subject(stc.subject, subject, True)

backend = _get_3d_backend()
del _get_3d_backend
using_mayavi = backend == "mayavi"
if using_mayavi:
from surfer import Brain
_require_version('surfer', 'stc.plot', '0.9')
else: # PyVista
from ._brain import _Brain as Brain
_check_option('hemi', hemi, ['lh', 'rh', 'split', 'both'])

_check_option('hemi', hemi, ['lh', 'rh', 'split', 'both'])
time_label, times = _handle_time(time_label, time_unit, stc.times)

# convert control points to locations in colormap
mapdata = _process_clim(clim, colormap, transparent, stc.data)
mapdata = _process_clim(clim, colormap, transparent, stc.data,
allow_pos_lims=not vec)

# XXX we should only need to do this for PySurfer/Mayavi, the PyVista
# plotter should be smart enough to do this separation in the cmap-to-ctab
# conversion. But this will need to be another refactoring that will
# hopefully restore this line:
#
# if _get_3d_backend() == 'mayavi':
# if using_mayavi:
_separate_map(mapdata)
colormap = mapdata['colormap']
diverging = 'pos_lims' in mapdata['clim']
Expand All @@ -1713,59 +1743,94 @@ def plot_source_estimates(stc, subject=None, surface='inflated', hemi='lh',
else:
hemis = [hemi]

if title is None:
title = subject if len(hemis) > 1 else '%s - %s' % (subject, hemis[0])
if overlay_alpha is None:
overlay_alpha = brain_alpha
if overlay_alpha == 0:
smoothing_steps = 1 # Disable smoothing to save time.

title = subject if len(hemis) > 1 else '%s - %s' % (subject, hemis[0])
kwargs = {
"subject_id": subject, "hemi": hemi, "surf": surface,
"title": title, "cortex": cortex, "size": size,
"background": background, "foreground": foreground,
"figure": figure, "subjects_dir": subjects_dir,
"views": views,
"views": views, "alpha": brain_alpha,
}
if _get_3d_backend() in ['pyvista', 'notebook']:
if backend in ['pyvista', 'notebook']:
kwargs["show"] = not time_viewer
else:
kwargs.update(_check_pysurfer_antialias(Brain))
with warnings.catch_warnings(record=True): # traits warnings
brain = Brain(**kwargs)
del kwargs
if scale_factor is None:
# Configure the glyphs scale directly
width = np.mean([np.ptp(brain.geo[hemi].coords[:, 1])
for hemi in hemis if hemi in brain.geo])
scale_factor = 0.025 * width / scale_pts[-1]

if transparent is None:
transparent = True
sd_kwargs = dict(transparent=transparent, verbose=False)
center = 0. if diverging else None
for hemi in hemis:
hemi_idx = 0 if hemi == 'lh' else 1
data = getattr(stc, hemi + '_data')
vertices = stc.vertices[hemi_idx]
if len(data) > 0:
if transparent is None:
transparent = True
kwargs = {
"array": data, "colormap": colormap,
"vertices": vertices,
"smoothing_steps": smoothing_steps,
"time": times, "time_label": time_label,
"alpha": alpha, "hemi": hemi,
"colorbar": colorbar, "initial_time": initial_time,
"transparent": transparent, "center": center,
"verbose": False
}
if _get_3d_backend() == "mayavi":
kwargs["min"] = scale_pts[0]
kwargs["mid"] = scale_pts[1]
kwargs["max"] = scale_pts[2]
else: # pyvista
kwargs["fmin"] = scale_pts[0]
kwargs["fmid"] = scale_pts[1]
kwargs["fmax"] = scale_pts[2]
kwargs["clim"] = clim
with warnings.catch_warnings(record=True): # traits warnings
brain.add_data(**kwargs)

_check_time_viewer_compatibility(brain, time_viewer, show_traces)
return brain
vertices = stc.vertices[0 if hemi == 'lh' else 1]
alpha = overlay_alpha
if len(data) == 0:
continue
kwargs = {
"array": data, "colormap": colormap,
"vertices": vertices,
"smoothing_steps": smoothing_steps,
"time": times, "time_label": time_label,
"alpha": alpha, "hemi": hemi,
"colorbar": colorbar,
"vector_alpha": vector_alpha,
"scale_factor": scale_factor,
"verbose": False,
"initial_time": initial_time,
"transparent": transparent, "center": center,
"verbose": False
}
if using_mayavi:
kwargs["min"] = scale_pts[0]
kwargs["mid"] = scale_pts[1]
kwargs["max"] = scale_pts[2]
else: # pyvista
kwargs["fmin"] = scale_pts[0]
kwargs["fmid"] = scale_pts[1]
kwargs["fmax"] = scale_pts[2]
kwargs["clim"] = clim
with warnings.catch_warnings(record=True): # traits warnings
brain.add_data(**kwargs)
brain.scale_data_colormap(fmin=scale_pts[0], fmid=scale_pts[1],
fmax=scale_pts[2], **sd_kwargs)

need_peeling = (brain_alpha < 1.0 and
sys.platform != 'darwin' and
vec)
if using_mayavi:
for hemi in hemis:
for b in brain._brain_list:
for layer in b['brain'].data.values():
glyphs = layer['glyphs']
if glyphs is None:
continue
glyphs.glyph.glyph.scale_factor = scale_factor
glyphs.glyph.glyph.clamping = False
glyphs.glyph.glyph.range = (0., 1.)

def _check_time_viewer_compatibility(brain, time_viewer, show_traces):
from .backends.renderer import _get_3d_backend
using_mayavi = _get_3d_backend() == "mayavi"
# depth peeling patch
if need_peeling:
for ff in brain._figures:
for f in ff:
if f.scene is not None and sys.platform != 'darwin':
f.scene.renderer.use_depth_peeling = True
elif need_peeling:
brain.enable_depth_peeling()

# time_viewer and show_traces
_check_option('time_viewer', time_viewer, (True, False, 'auto'))
_check_option('show_traces', show_traces,
(True, False, 'auto', 'separate'))
Expand All @@ -1778,18 +1843,11 @@ def _check_time_viewer_compatibility(brain, time_viewer, show_traces):
brain._times is not None and
len(brain._times) > 1
)

if _get_3d_backend() == "mayavi" and all([time_viewer, show_traces]):
raise NotImplementedError("Point picking is not available"
" for the mayavi 3d backend.")
if using_mayavi:
if not check_version('surfer', '0.9'):
raise RuntimeError('This function requires pysurfer version '
'>= 0.9')

if show_traces and not time_viewer:
raise ValueError('show_traces cannot be used when time_viewer=False')

if using_mayavi and show_traces:
raise NotImplementedError("show_traces=True is not available "
"for the mayavi 3d backend.")
if time_viewer:
if using_mayavi:
from surfer import TimeViewer
Expand All @@ -1798,6 +1856,8 @@ def _check_time_viewer_compatibility(brain, time_viewer, show_traces):
from ._brain import _TimeViewer as TimeViewer
TimeViewer(brain, show_traces=show_traces)

return brain


def _glass_brain_crosshairs(params, x, y, z):
for ax, a, b in ((params['ax_y'], x, z),
Expand Down Expand Up @@ -2343,118 +2403,15 @@ def plot_vector_source_estimates(stc, subject=None, hemi='lh', colormap='hot',
If the current magnitude overlay is not desired, set ``overlay_alpha=0``
and ``smoothing_steps=1``.
"""
from .backends.renderer import _get_3d_backend
# Import here to avoid circular imports
if _get_3d_backend() == "mayavi":
from surfer import Brain
from surfer import __version__ as surfer_version
else: # PyVista
from ._brain import _Brain as Brain
from ..source_estimate import VectorSourceEstimate

_validate_type(stc, VectorSourceEstimate, "stc", "Vector Source Estimate")
subjects_dir = get_subjects_dir(subjects_dir=subjects_dir,
raise_error=True)
subject = _check_subject(stc.subject, subject, True)
_check_option('hemi', hemi, ['lh', 'rh', 'split', 'both'])
time_label, times = _handle_time(time_label, time_unit, stc.times)

# convert control points to locations in colormap
mapdata = _process_clim(clim, colormap, transparent, stc.data,
allow_pos_lims=False)
colormap = mapdata['colormap']
scale_pts = mapdata['clim']['lims'] # pos_lims not allowed
transparent = mapdata['transparent']
del mapdata

if hemi in ['both', 'split']:
hemis = ['lh', 'rh']
else:
hemis = [hemi]

if overlay_alpha is None:
overlay_alpha = brain_alpha
if overlay_alpha == 0:
smoothing_steps = 1 # Disable smoothing to save time.

title = subject if len(hemis) > 1 else '%s - %s' % (subject, hemis[0])
kwargs = {
"subject_id": subject, "hemi": hemi, "surf": 'white',
"title": title, "cortex": cortex, "size": size,
"background": background, "foreground": foreground,
"figure": figure, "subjects_dir": subjects_dir,
"views": views, "alpha": brain_alpha,
}
if _get_3d_backend() in ['pyvista', 'notebook']:
kwargs["show"] = not time_viewer
else:
kwargs.update(_check_pysurfer_antialias(Brain))
with warnings.catch_warnings(record=True): # traits warnings
brain = Brain(**kwargs)
del kwargs
if scale_factor is None:
# Configure the glyphs scale directly
width = np.mean([np.ptp(brain.geo[hemi].coords[:, 1])
for hemi in hemis if hemi in brain.geo])
scale_factor = 0.025 * width / scale_pts[-1]

sd_kwargs = dict(transparent=transparent, verbose=False)
for hemi in hemis:
hemi_idx = 0 if hemi == 'lh' else 1
data = getattr(stc, hemi + '_data')
vertices = stc.vertices[hemi_idx]
if len(data) > 0:
kwargs = {
"array": data, "colormap": colormap,
"vertices": vertices,
"smoothing_steps": smoothing_steps,
"time": times, "time_label": time_label,
"alpha": overlay_alpha, "hemi": hemi,
"colorbar": colorbar,
"vector_alpha": vector_alpha,
"scale_factor": scale_factor,
"verbose": False,
}
if initial_time is not None:
kwargs['initial_time'] = initial_time
if _get_3d_backend() == "mayavi":
if surfer_version >= LooseVersion('0.9'):
kwargs["transparent"] = transparent
kwargs["min"] = scale_pts[0]
kwargs["mid"] = scale_pts[1]
kwargs["max"] = scale_pts[2]
else:
kwargs["transparent"] = transparent
kwargs["fmin"] = scale_pts[0]
kwargs["fmid"] = scale_pts[1]
kwargs["fmax"] = scale_pts[2]
with warnings.catch_warnings(record=True): # traits warnings
brain.add_data(**kwargs)
brain.scale_data_colormap(fmin=scale_pts[0], fmid=scale_pts[1],
fmax=scale_pts[2], **sd_kwargs)

if _get_3d_backend() == "mayavi":
for hemi in hemis:
for b in brain._brain_list:
for layer in b['brain'].data.values():
glyphs = layer['glyphs']
glyphs.glyph.glyph.scale_factor = scale_factor
glyphs.glyph.glyph.clamping = False
glyphs.glyph.glyph.range = (0., 1.)

# depth peeling patch
if brain_alpha < 1.0:
for ff in brain._figures:
for f in ff:
if f.scene is not None and sys.platform != 'darwin':
f.scene.renderer.use_depth_peeling = True
else:
if brain_alpha < 1.0 and sys.platform != 'darwin':
brain.enable_depth_peeling()

_check_time_viewer_compatibility(brain, time_viewer, show_traces)

return brain
return _plot_stc(
stc, subject=subject, surface='white', hemi=hemi, colormap=colormap,
time_label=time_label, smoothing_steps=smoothing_steps,
subjects_dir=subjects_dir, views=views, clim=clim, figure=figure,
initial_time=initial_time, time_unit=time_unit, background=background,
time_viewer=time_viewer, colorbar=colorbar, transparent=transparent,
brain_alpha=brain_alpha, overlay_alpha=overlay_alpha,
vector_alpha=vector_alpha, cortex=cortex, foreground=foreground,
size=size, scale_factor=scale_factor, show_traces=show_traces)


@verbose
Expand Down