Skip to content

Commit

Permalink
Merge pull request #861 from ioam/bokeh_colorbars
Browse files Browse the repository at this point in the history
Bokeh colorbars
  • Loading branch information
jlstevens authored Sep 14, 2016
2 parents 4664863 + 9e2f1b6 commit 96be6a7
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 121 deletions.
64 changes: 25 additions & 39 deletions holoviews/plotting/bokeh/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from ...core.util import max_range, basestring, dimension_sanitizer
from ...core.options import abbreviated_exception
from ..util import compute_sizes, get_sideplot_ranges, match_spec, map_colors
from .element import ElementPlot, line_properties, fill_properties
from .element import ElementPlot, ColorbarPlot, line_properties, fill_properties
from .path import PathPlot, PolygonPlot
from .util import get_cmap, mpl_to_bokeh, update_plot, rgb2hex, bokeh_version


class PointPlot(ElementPlot):
class PointPlot(ColorbarPlot):

color_index = param.ClassSelector(default=3, class_=(basestring, int),
allow_None=True, doc="""
Expand Down Expand Up @@ -55,21 +55,12 @@ def get_data(self, element, ranges=None, empty=False):
mapping = dict(x=dims[xidx], y=dims[yidx])
data = {}

cmap = style.get('palette', style.get('cmap', None))
cdim = element.get_dimension(self.color_index)
if cdim and cmap:
map_key = 'color_' + cdim.name
mapping['color'] = map_key
if empty:
data[map_key] = []
else:
cmap = get_cmap(cmap)
colors = element.dimension_values(self.color_index)
if colors.dtype.kind in 'if':
crange = ranges.get(cdim.name, element.range(cdim.name))
else:
crange = np.unique(colors)
data[map_key] = map_colors(colors, crange, cmap)
if cdim:
mapper = self._get_colormapper(cdim, element, ranges, style)
data[cdim.name] = [] if empty else element.dimension_values(cdim)
mapping['color'] = {'field': cdim.name,
'transform': mapper}

sdim = element.get_dimension(self.size_index)
if sdim:
Expand Down Expand Up @@ -98,7 +89,7 @@ def get_batched_data(self, element, ranges=None, empty=False):
eldata, elmapping = self.get_data(el, ranges, empty)
for k, eld in eldata.items():
data[k].append(eld)
if 'color' not in eldata:
if 'color' not in elmapping:
zorder = self.get_zorder(element, key, el)
val = style[zorder].get('color')
elmapping['color'] = 'color'
Expand Down Expand Up @@ -128,6 +119,8 @@ def _init_glyph(self, plot, mapping, properties):
else:
plot_method = self._plot_methods.get('batched' if self.batched else 'single')
renderer = getattr(plot, plot_method)(**dict(properties, **mapping))
if self.colorbar and 'color_mapper' in self.handles:
self._draw_colorbar(plot, self.handles['color_mapper'])
return renderer, renderer.glyph


Expand Down Expand Up @@ -239,7 +232,7 @@ def get_data(self, element, ranges=None, empty=None):
return (data, mapping)


class SideHistogramPlot(HistogramPlot):
class SideHistogramPlot(HistogramPlot, ColorbarPlot):

style_opts = HistogramPlot.style_opts + ['cmap']

Expand All @@ -262,19 +255,20 @@ def get_data(self, element, ranges=None, empty=None):
data = dict(top=element.values, left=element.edges[:-1],
right=element.edges[1:])

dim = element.get_dimension(0).name
dim = element.get_dimension(0)
main = self.adjoined.main
range_item, main_range, dim = get_sideplot_ranges(self, element, main, ranges)
vals = element.dimension_values(dim)
range_item, main_range, _ = get_sideplot_ranges(self, element, main, ranges)
if isinstance(range_item, (Raster, Points, Polygons, Spikes)):
style = self.lookup_options(range_item, 'style')[self.cyclic_index]
else:
style = {}

if 'cmap' in style or 'palette' in style:
cmap = get_cmap(style.get('cmap', style.get('palette', None)))
data['color'] = [] if empty else map_colors(vals, main_range, cmap)
mapping['fill_color'] = 'color'
main_range = {dim.name: main_range}
cmapper = self._get_colormapper(dim, element, main_range, style)
data[dim.name] = [] if empty else element.dimension_values(dim)
mapping['fill_color'] = {'field': dim.name,
'transform': cmapper}
self._get_hover_data(data, element, empty)
return (data, mapping)

Expand Down Expand Up @@ -314,7 +308,7 @@ def get_data(self, element, ranges=None, empty=False):
return (data, dict(self._mapping))


class SpikesPlot(PathPlot):
class SpikesPlot(PathPlot, ColorbarPlot):

color_index = param.ClassSelector(default=1, class_=(basestring, int), doc="""
Index of the dimension from which the color will the drawn""")
Expand Down Expand Up @@ -352,22 +346,14 @@ def get_data(self, element, ranges=None, empty=False):
xs, ys = zip(*(((x[0], x[0]), (pos+height, pos))
for x in element.array(dims[:1])))

if not empty and self.invert_axes: keys = keys[::-1]
if not empty and self.invert_axes: xs, ys = ys, xs
data = dict(zip(('xs', 'ys'), (xs, ys)))

cmap = style.get('palette', style.get('cmap', None))
cdim = element.get_dimension(self.color_index)
if cdim and cmap:
map_key = 'color_' + cdim.name
mapping['color'] = map_key
if empty:
colors = []
else:
cmap = get_cmap(cmap)
cvals = element.dimension_values(cdim)
crange = ranges.get(cdim.name, None)
colors = map_colors(cvals, crange, cmap)
data[map_key] = colors
if cdim:
cmapper = self._get_colormapper(cdim, element, ranges, style)
data[cdim.name] = [] if empty else element.dimension_values(cdim)
mapping['color'] = {'field': cdim.name,
'transform': cmapper}

if 'hover' in self.tools+self.default_tools and not empty:
for d in dims:
Expand Down
185 changes: 162 additions & 23 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from bokeh.models.tickers import Ticker, BasicTicker, FixedTicker
from bokeh.models.widgets import Panel, Tabs

from bokeh.models.mappers import LinearColorMapper
try:
from bokeh.models import ColorBar
from bokeh.models.mappers import LogColorMapper
except ImportError:
LogColorMapper, ColorBar = None, None
from bokeh.models import LogTicker, BasicTicker

try:
from bokeh import mpl
except ImportError:
Expand All @@ -22,7 +30,8 @@
from ..util import dynamic_update
from .callbacks import Callbacks
from .plot import BokehPlot
from .util import mpl_to_bokeh, convert_datetime, update_plot, bokeh_version
from .util import (mpl_to_bokeh, convert_datetime, update_plot,
bokeh_version, mplcmap_to_palette)

if bokeh_version >= '0.12':
from bokeh.models import FuncTickFormatter
Expand Down Expand Up @@ -104,6 +113,13 @@ class ElementPlot(BokehPlot, GenericElementPlot):
tools = param.List(default=[], doc="""
A list of plugin tools to use on the plot.""")

toolbar = param.ObjectSelector(default='right',
objects=["above", "below",
"left", "right", None],
doc="""
The toolbar location, must be one of 'above', 'below',
'left', 'right', None.""")

xaxis = param.ObjectSelector(default='bottom',
objects=['top', 'bottom', 'bare', 'top-bare',
'bottom-bare', None], doc="""
Expand Down Expand Up @@ -268,7 +284,6 @@ def _init_plot(self, key, element, plots, ranges=None):
axis_types, labels, plot_ranges = self._axes_props(plots, subplots, element, ranges)
xlabel, ylabel, _ = labels
x_axis_type, y_axis_type = axis_types
tools = self._init_tools(element)
properties = dict(plot_ranges)
properties['x_axis_label'] = xlabel if 'x' in self.show_labels else ' '
properties['y_axis_label'] = ylabel if 'y' in self.show_labels else ' '
Expand All @@ -278,10 +293,15 @@ def _init_plot(self, key, element, plots, ranges=None):
else:
title = ''

if self.toolbar:
tools = self._init_tools(element)
properties['tools'] = tools
properties['toolbar_location'] = self.toolbar

properties['webgl'] = Store.renderers[self.renderer.backend].webgl
return bokeh.plotting.Figure(x_axis_type=x_axis_type,
y_axis_type=y_axis_type, title=title,
tools=tools, **properties)
**properties)


def _plot_properties(self, key, plot, element):
Expand Down Expand Up @@ -618,6 +638,144 @@ def framewise(self):
for frame in current_frames)



class ColorbarPlot(ElementPlot):
"""
ColorbarPlot provides methods to create colormappers and colorbar
models which can be added to a glyph. Additionally it provides
parameters to control the position and other styling options of
the colorbar. The default colorbar_position options are defined
by the colorbar_specs, but may be overridden by the colorbar_opts.
"""

colorbar_specs = {'right': {'pos': 'right',
'opts': {'location': (0, 0)}},
'left': {'pos': 'left',
'opts':{'location':(0, 0)}},
'bottom': {'pos': 'below',
'opts': {'location': (0, 0),
'orientation':'horizontal'}},
'top': {'pos': 'above',
'opts': {'location':(0, 0),
'orientation':'horizontal'}},
'top_right': {'pos': 'center',
'opts': {'location': 'top_right'}},
'top_left': {'pos': 'center',
'opts': {'location': 'top_left'}},
'bottom_left': {'pos': 'center',
'opts': {'location': 'bottom_left',
'orientation': 'horizontal'}},
'bottom_right': {'pos': 'center',
'opts': {'location': 'bottom_right',
'orientation': 'horizontal'}}}

colorbar = param.Boolean(default=False, doc="""
Whether to display a colorbar.""")

colorbar_position = param.ObjectSelector(objects=list(colorbar_specs),
default="right", doc="""
Allows selecting between a number of predefined colorbar position
options. The predefined options may be customized in the
colorbar_specs class attribute.""")

colorbar_opts = param.Dict(default={}, doc="""
Allows setting specific styling options for the colorbar overriding
the options defined in the colorbar_specs class attribute. Includes
location, orientation, height, width, scale_alpha, title, title_props,
margin, padding, background_fill_color and more.""")

logz = param.Boolean(default=False, doc="""
Whether to apply log scaling to the z-axis.""")

_update_handles = ['color_mapper', 'source', 'glyph']

_colorbar_defaults = dict(bar_line_color='black', label_standoff=8,
major_tick_line_color='black')

def _draw_colorbar(self, plot, color_mapper):
if LogColorMapper and isinstance(color_mapper, LogColorMapper):
ticker = LogTicker()
else:
ticker = BasicTicker()
cbar_opts = dict(self.colorbar_specs[self.colorbar_position])

# Check if there is a colorbar in the same position
pos = cbar_opts['pos']
if any(isinstance(model, ColorBar) for model in getattr(plot, pos, [])):
return

opts = dict(cbar_opts['opts'], self._colorbar_defaults)
color_bar = ColorBar(color_mapper=color_mapper, ticker=ticker,
**dict(opts, **self.colorbar_opts))

plot.add_layout(color_bar, pos)
self.handles['colorbar'] = color_bar


def _get_colormapper(self, dim, element, ranges, style):
low, high = ranges.get(dim.name)
palette = mplcmap_to_palette(style.pop('cmap', 'viridis'))
colormapper = LogColorMapper if self.logz else LinearColorMapper
cmapper = colormapper(palette, low=low, high=high)

# The initial colormapper instance is cached the first time
# and then updated with the values from new instances
if 'color_mapper' not in self.handles:
self.handles['color_mapper'] = cmapper
return cmapper


def _init_glyph(self, plot, mapping, properties):
"""
Returns a Bokeh glyph object and optionally creates a colorbar.
"""
ret = super(ColorbarPlot, self)._init_glyph(plot, mapping, properties)
if self.colorbar and 'color_mapper' in self.handles:
self._draw_colorbar(plot, self.handles['color_mapper'])
return ret


def _update_glyph(self, glyph, properties, mapping):
allowed_properties = glyph.properties()
cmappers = [v.get('transform') for v in mapping.values()
if isinstance(v, dict)]
cmappers.append(properties.pop('color_mapper', None))
for cm in cmappers:
if cm:
self.handles['color_mapper'].low = cm.low
self.handles['color_mapper'].high = cm.high
self.handles['color_mapper'].palette = cm.palette
merged = dict(properties, **mapping)
glyph.set(**{k: v for k, v in merged.items()
if k in allowed_properties})


class LegendPlot(ElementPlot):

legend_position = param.ObjectSelector(objects=["top_right",
"top_left",
"bottom_left",
"bottom_right",
'right', 'left',
'top', 'bottom'],
default="top_right",
doc="""
Allows selecting between a number of predefined legend position
options. The predefined options may be customized in the
legend_specs class attribute.""")


legend_cols = param.Integer(default=False, doc="""
Whether to lay out the legend as columns.""")


legend_specs = {'right': dict(pos='right', loc=(5, -40)),
'left': dict(pos='left', loc=(0, -40)),
'top': dict(pos='above', loc=(120, 5)),
'bottom': dict(pos='below', loc=(60, 0))}



class BokehMPLWrapper(ElementPlot):
"""
Wraps an existing HoloViews matplotlib plot and converts
Expand Down Expand Up @@ -710,22 +868,8 @@ def update_frame(self, key, ranges=None, element=None):
self.handles['plot'] = self._render_plot(element)


class OverlayPlot(GenericOverlayPlot, ElementPlot):

legend_position = param.ObjectSelector(objects=["top_right",
"top_left",
"bottom_left",
"bottom_right",
'right', 'left',
'top', 'bottom'],
default="top_right",
doc="""
Allows selecting between a number of predefined legend position
options. The predefined options may be customized in the
legend_specs class attribute.""")
class OverlayPlot(GenericOverlayPlot, LegendPlot):

legend_cols = param.Integer(default=False, doc="""
Whether to lay out the legend as columns.""")

tabs = param.Boolean(default=False, doc="""
Whether to display overlaid plots in separate panes""")
Expand All @@ -734,11 +878,6 @@ class OverlayPlot(GenericOverlayPlot, ElementPlot):

_update_handles = ['source']

legend_specs = {'right': dict(pos='right', loc=(5, -40)),
'left': dict(pos='left', loc=(0, -40)),
'top': dict(pos='above', loc=(120, 5)),
'bottom': dict(pos='below', loc=(60, 0))}

def _process_legend(self):
plot = self.handles['plot']
if not self.show_legend or len(plot.legend) == 0:
Expand Down
Loading

0 comments on commit 96be6a7

Please sign in to comment.