Skip to content

Commit

Permalink
Merge pull request #1137 from ioam/categorical_cmapper
Browse files Browse the repository at this point in the history
Added support for categorical colormapping in bokeh backend
  • Loading branch information
jlstevens authored Feb 23, 2017
2 parents 885c1bc + 3a8adf4 commit ed01c3c
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 46 deletions.
38 changes: 15 additions & 23 deletions holoviews/plotting/bokeh/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .util import get_cmap, mpl_to_bokeh, update_plot, rgb2hex, bokeh_version


class PointPlot(ColorbarPlot):
class PointPlot(LegendPlot, ColorbarPlot):

color_index = param.ClassSelector(default=3, class_=(basestring, int),
allow_None=True, doc="""
Expand Down Expand Up @@ -58,10 +58,22 @@ def get_data(self, element, ranges=None, empty=False):
mapping = dict(x=dims[xidx], y=dims[yidx])
data = {}

xdim, ydim = dims[xidx], dims[yidx]
data[xdim] = [] if empty else element.dimension_values(xidx)
data[ydim] = [] if empty else element.dimension_values(yidx)
self._categorize_data(data, (xdim, ydim), element.dimensions())

cdim = element.get_dimension(self.color_index)
if cdim:
mapper = self._get_colormapper(cdim, element, ranges, style)
data[cdim.name] = [] if empty else element.dimension_values(cdim)
cdata = data[cdim.name] if cdim.name in data else element.dimension_values(cdim)
factors = None
if isinstance(cdata, list) or cdata.dtype.kind in 'OSU':
factors = list(np.unique(cdata))
mapper = self._get_colormapper(cdim, element, ranges, style,
factors)
data[cdim.name] = cdata
if factors is not None:
mapping['legend'] = {'field': cdim.name}
mapping['color'] = {'field': cdim.name,
'transform': mapper}

Expand All @@ -85,10 +97,6 @@ def get_data(self, element, ranges=None, empty=False):
data[map_key] = np.sqrt(sizes)
mapping['size'] = map_key

xdim, ydim = dims[xidx], dims[yidx]
data[xdim] = [] if empty else element.dimension_values(xidx)
data[ydim] = [] if empty else element.dimension_values(yidx)
self._categorize_data(data, (xdim, ydim), element.dimensions())
self._get_hover_data(data, element, empty)
return data, mapping

Expand Down Expand Up @@ -520,22 +528,6 @@ def initialize_plot(self, ranges=None, plot=None, plots=None, source=None):

return plot

def _process_legend(self, plot):
legend = plot.legend[0]
if not self.show_legend:
legend.items[:] = []
else:
plot.legend.orientation = 'horizontal' if self.legend_cols else 'vertical'
pos = self.legend_position
if pos in self.legend_specs:
opts = self.legend_specs[pos]
plot.legend[:] = []
legend.plot = None
legend.location = opts['loc']
plot.add_layout(legend, opts['pos'])
else:
legend.location = pos

def update_frame(self, key, ranges=None, plot=None, element=None):
"""
Updates an existing plot with data corresponding
Expand Down
70 changes: 49 additions & 21 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from bokeh.models.mappers import LinearColorMapper
try:
from bokeh.models import ColorBar
from bokeh.models.mappers import LogColorMapper
from bokeh.models.mappers import LogColorMapper, CategoricalColorMapper
except ImportError:
LogColorMapper, ColorBar = None, None
from bokeh.models import LogTicker, BasicTicker
Expand All @@ -34,7 +34,8 @@
from ..util import dynamic_update, get_sources
from .plot import BokehPlot
from .util import (mpl_to_bokeh, convert_datetime, update_plot, get_tab_title,
bokeh_version, mplcmap_to_palette, py2js_tickformatter)
bokeh_version, mplcmap_to_palette, py2js_tickformatter,
rgba_tuple)

if bokeh_version >= '0.12':
from bokeh.models import FuncTickFormatter
Expand Down Expand Up @@ -851,6 +852,8 @@ class ColorbarPlot(ElementPlot):
major_tick_line_color='black')

def _draw_colorbar(self, plot, color_mapper):
if CategoricalColorMapper and isinstance(color_mapper, CategoricalColorMapper):
return
if LogColorMapper and isinstance(color_mapper, LogColorMapper):
ticker = LogTicker()
else:
Expand All @@ -870,13 +873,11 @@ def _draw_colorbar(self, plot, color_mapper):
self.handles['colorbar'] = color_bar


def _get_colormapper(self, dim, element, ranges, style):
def _get_colormapper(self, dim, element, ranges, style, factors=None):
# The initial colormapper instance is cached the first time
# and then only updated
if dim is None:
return None
low, high = ranges.get(dim.name, element.range(dim.name))
palette = mplcmap_to_palette(style.pop('cmap', 'viridis'))
if self.adjoined:
cmappers = self.adjoined.traverse(lambda x: (x.handles.get('color_dim'),
x.handles.get('color_mapper')))
Expand All @@ -887,30 +888,40 @@ def _get_colormapper(self, dim, element, ranges, style):
return cmapper
else:
return None
colors = self.clipping_colors
if isinstance(low, (bool, np.bool_)): low = int(low)
if isinstance(high, (bool, np.bool_)): high = int(high)
opts = {'low': low, 'high': high}
color_opts = [('NaN', 'nan_color'), ('max', 'high_color'), ('min', 'low_color')]
for name, opt in color_opts:
color = colors.get(name)
if not color:
continue
elif isinstance(color, tuple):
color = [int(c*255) if i<3 else c for i, c in enumerate(color)]
opts[opt] = color

ncolors = None if factors is None else len(factors)
low, high = ranges.get(dim.name, element.range(dim.name))
palette = mplcmap_to_palette(style.pop('cmap', 'viridis'), ncolors)
colors = {k: rgba_tuple(v) for k, v in self.clipping_colors.items()}
colormapper, opts = self._get_cmapper_opts(low, high, factors, colors)

if 'color_mapper' in self.handles:
cmapper = self.handles['color_mapper']
cmapper.palette = palette
cmapper.update(**opts)
else:
colormapper = LogColorMapper if self.logz else LinearColorMapper
cmapper = colormapper(palette, **opts)
cmapper = colormapper(palette=palette, **opts)
self.handles['color_mapper'] = cmapper
self.handles['color_dim'] = dim
return cmapper


def _get_cmapper_opts(self, low, high, factors, colors):
if factors is None:
colormapper = LogColorMapper if self.logz else LinearColorMapper
if isinstance(low, (bool, np.bool_)): low = int(low)
if isinstance(high, (bool, np.bool_)): high = int(high)
opts = {'low': low, 'high': high}
color_opts = [('NaN', 'nan_color'), ('max', 'high_color'), ('min', 'low_color')]
opts.update({opt: colors[name] for name, opt in color_opts if name in colors})
else:
colormapper = CategoricalColorMapper
opts = dict(factors=factors)
if 'NaN' in colors:
opts['nan_color'] = colors['NaN']
return colormapper, opts


def _init_glyph(self, plot, mapping, properties):
"""
Returns a Bokeh glyph object and optionally creates a colorbar.
Expand Down Expand Up @@ -942,16 +953,33 @@ class LegendPlot(ElementPlot):
options. The predefined options may be customized in the
legend_specs class attribute.""")


legend_cols = param.Integer(default=False, doc="""
Whether to lay out the legend as columns.""")


legend_specs = {'right': dict(pos='right', loc=(5, -40)),
'left': dict(pos='left', loc=(0, -40)),
'top': dict(pos='above', loc=(120, 5)),
'bottom': dict(pos='below', loc=(60, 0))}

def _process_legend(self, plot=None):
plot = plot or self.handles['plot']
if not plot.legend:
return
legend = plot.legend[0]
if not self.show_legend:
legend.items[:] = []
else:
plot.legend.orientation = 'horizontal' if self.legend_cols else 'vertical'
pos = self.legend_position
if pos in self.legend_specs:
opts = self.legend_specs[pos]
plot.legend[:] = []
legend.plot = None
legend.location = opts['loc']
plot.add_layout(legend, opts['pos'])
else:
legend.location = pos



class BokehMPLWrapper(ElementPlot):
Expand Down
14 changes: 13 additions & 1 deletion holoviews/plotting/bokeh/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,26 @@ def rgb2hex(rgb):
return "#{0:02x}{1:02x}{2:02x}".format(*(int(v*255) for v in rgb))


def mplcmap_to_palette(cmap):
def rgba_tuple(rgba):
"""
Ensures RGB(A) tuples in the range 0-1 are scaled to 0-255.
"""
if isinstance(rgba, tuple):
return [int(c*255) if i<3 else c for i, c in enumerate(rgba)]
else:
return rgba


def mplcmap_to_palette(cmap, ncolors=None):
"""
Converts a matplotlib colormap to palette of RGB hex strings."
"""
if colors is None:
raise ValueError("Using cmaps on objects requires matplotlib.")
with abbreviated_exception():
colormap = cm.get_cmap(cmap) #choose any matplotlib colormap here
if ncolors:
return [rgb2hex(colormap(i)) for i in np.linspace(0, 1, ncolors)]
return [rgb2hex(m) for m in colormap(np.arange(colormap.N))]


Expand Down
13 changes: 12 additions & 1 deletion tests/testplotinstantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
Div, ColumnDataSource, FactorRange, Range1d, Row, Column,
ToolbarBox, Spacer
)
from bokeh.models.mappers import LinearColorMapper, LogColorMapper
from bokeh.models.mappers import (LinearColorMapper, LogColorMapper,
CategoricalColorMapper)
from bokeh.models.tools import HoverTool
from bokeh.plotting import Figure
except:
Expand Down Expand Up @@ -297,6 +298,16 @@ def test_points_colormapping(self):
points = Points(np.random.rand(10, 4), vdims=['a', 'b'])
self._test_colormapping(points, 3)

def test_points_colormapping_categorical(self):
points = Points([(i, i*2, i*3, chr(65+i)) for i in range(10)],
vdims=['a', 'b'])
plot = bokeh_renderer.get_plot(points)
plot.initialize_plot()
fig = plot.state
cmapper = plot.handles['color_mapper']
self.assertIsInstance(cmapper, CategoricalColorMapper)
self.assertEqual(cmapper.factors, list(points['b']))

def test_image_colormapping(self):
img = Image(np.random.rand(10, 10))(plot=dict(logz=True))
self._test_colormapping(img, 2, True)
Expand Down

0 comments on commit ed01c3c

Please sign in to comment.