Skip to content

Commit

Permalink
Ensure plotting code handles custom array types (#3792)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Jun 25, 2019
1 parent f9e7740 commit 480a5e5
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 14 deletions.
4 changes: 2 additions & 2 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,8 @@ def _apply_transforms(self, element, data, ranges, style, group=None):
data[k] = val

# If color is not valid colorspec add colormapper
numeric = isinstance(val, np.ndarray) and val.dtype.kind in 'uifMm'
if ('color' in k and isinstance(val, np.ndarray) and
numeric = isinstance(val, util.arraylike_types) and val.dtype.kind in 'uifMm'
if ('color' in k and isinstance(val, util.arraylike_types) and
(numeric or not validate('color', val))):
kwargs = {}
if val.dtype.kind not in 'ifMu':
Expand Down
4 changes: 2 additions & 2 deletions holoviews/plotting/bokeh/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
cm, colors = None, None

from ...core.options import abbreviated_exception
from ...core.util import basestring
from ...core.util import basestring, arraylike_types
from ...util.transform import dim
from ..util import COLOR_ALIASES, RGB_HEX_REGEX, rgb2hex

Expand Down Expand Up @@ -133,7 +133,7 @@ def validate(style, value, scalar=False):
validator = get_validator(style)
if validator is None:
return None
if isinstance(value, (np.ndarray, list)):
if isinstance(value, arraylike_types+(list,)):
if scalar:
return False
return all(validator(v) for v in value)
Expand Down
7 changes: 4 additions & 3 deletions holoviews/plotting/bokeh/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ...core.overlay import Overlay
from ...core.util import (
LooseVersion, _getargspec, basestring, callable_name, cftime_types,
cftime_to_timestamp, pd, unique_array, isnumeric)
cftime_to_timestamp, pd, unique_array, isnumeric, arraylike_types)
from ...core.spaces import get_nested_dmaps, DynamicMap
from ..util import dim_axis_label

Expand Down Expand Up @@ -98,7 +98,7 @@ def decode_bytes(array):
bokeh serialization errors
"""
if (sys.version_info.major == 2 or not len(array) or
(isinstance(array, np.ndarray) and array.dtype.kind != 'O')):
(isinstance(array, arraylike_types) and array.dtype.kind != 'O')):
return array
decoded = [v.decode('utf-8') if isinstance(v, bytes) else v for v in array]
if isinstance(array, np.ndarray):
Expand Down Expand Up @@ -603,7 +603,8 @@ def cds_column_replace(source, data):
needs to be updated. A replacement is required if untouched
columns are not the same length as the columns being updated.
"""
current_length = [len(v) for v in source.data.values() if isinstance(v, (list, np.ndarray))]
current_length = [len(v) for v in source.data.values()
if isinstance(v, (list,)+arraylike_types)]
new_length = [len(v) for v in data.values() if isinstance(v, (list, np.ndarray))]
untouched = [k for k in source.data if k not in data]
return bool(untouched and current_length and new_length and current_length[0] != new_length[0])
Expand Down
4 changes: 2 additions & 2 deletions holoviews/plotting/mpl/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def _apply_transforms(self, element, ranges, style):
groups = [sg for sg in style_groups if k.startswith(sg)]
group = groups[0] if groups else None
prefix = '' if group is None else group+'_'
if (k in (prefix+'c', prefix+'color') and isinstance(val, np.ndarray)
if (k in (prefix+'c', prefix+'color') and isinstance(val, util.arraylike_types)
and not validate('color', val)):
new_style.pop(k)
self._norm_kwargs(element, ranges, new_style, v, val, prefix)
Expand Down Expand Up @@ -617,7 +617,7 @@ def _apply_transforms(self, element, ranges, style):
(prefix != 'edge' or getattr(self, 'filled', True))
and any(o.startswith(prefix+'face') for o in self.style_opts))

if k in (prefix+'c', prefix+'color') and isinstance(val, np.ndarray):
if k in (prefix+'c', prefix+'color') and isinstance(val, util.arraylike_types):
fill_style = new_style.get(prefix+'facecolor')
if fill_style and validate('color', fill_style):
new_style.pop('facecolor')
Expand Down
2 changes: 1 addition & 1 deletion holoviews/plotting/mpl/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_data(self, element, ranges, style):
style = self._apply_transforms(element, ranges, style)

cdim = element.get_dimension(self.color_index)
style_mapping = any(True for v in style.values() if isinstance(v, np.ndarray))
style_mapping = any(True for v in style.values() if isinstance(v, util.arraylike_types))
dims = element.kdims
xdim, ydim = dims
generic_dt_format = Dimension.type_formatters[np.datetime64]
Expand Down
5 changes: 3 additions & 2 deletions holoviews/plotting/mpl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
nc_axis_available = False

from ...core.util import (
LooseVersion, _getargspec, basestring, cftime_types, is_number)
LooseVersion, _getargspec, arraylike_types, basestring,
cftime_types, is_number,)
from ...element import Raster, RGB, Polygons
from ..util import COLOR_ALIASES, RGB_HEX_REGEX

Expand Down Expand Up @@ -89,7 +90,7 @@ def validate(style, value, vectorized=True):
validator = get_validator(style)
if validator is None:
return None
if isinstance(value, (np.ndarray, list)) and vectorized:
if isinstance(value, arraylike_types+(list,)) and vectorized:
return all(validator(v) for v in value)
try:
valid = validator(value)
Expand Down
4 changes: 2 additions & 2 deletions holoviews/plotting/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..core.spaces import get_nested_streams
from ..core.util import (match_spec, wrap_tuple, basestring, get_overlay_spec,
unique_iterator, closest_match, is_number, isfinite,
python2sort, disable_constant)
python2sort, disable_constant, arraylike_types)
from ..streams import LinkedStream
from ..util.transform import dim

Expand Down Expand Up @@ -518,7 +518,7 @@ def map_colors(arr, crange, cmap, hex=True):
Maps an array of values to RGB hex strings, given
a color range and colormap.
"""
if isinstance(crange, np.ndarray):
if isinstance(crange, arraylike_types):
xsorted = np.argsort(crange)
ypos = np.searchsorted(crange, arr)
arr = xsorted[ypos]
Expand Down
14 changes: 14 additions & 0 deletions holoviews/tests/plotting/bokeh/testpointplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,20 @@ def test_point_categorical_color_op(self):
self.assertEqual(glyph.fill_color, {'field': 'color', 'transform': cmapper})
self.assertEqual(glyph.line_color, {'field': 'color', 'transform': cmapper})

def test_point_categorical_dtype_color_op(self):
df = pd.DataFrame(dict(sample_id=['subject 1', 'subject 2', 'subject 3', 'subject 4'], category=['apple', 'pear', 'apple', 'pear'], value=[1, 2, 3, 4]))
df['category'] = df['category'].astype('category')
points = Points(df, ['sample_id', 'value']).opts(color='category')
plot = bokeh_renderer.get_plot(points)
cds = plot.handles['cds']
glyph = plot.handles['glyph']
cmapper = plot.handles['color_color_mapper']
self.assertTrue(cmapper, CategoricalColorMapper)
self.assertEqual(cmapper.factors, ['apple', 'pear'])
self.assertEqual(np.asarray(cds.data['color']), np.array(['apple', 'pear', 'apple', 'pear']))
self.assertEqual(glyph.fill_color, {'field': 'color', 'transform': cmapper})
self.assertEqual(glyph.line_color, {'field': 'color', 'transform': cmapper})

def test_point_explicit_cmap_color_op(self):
points = Points([(0, 0), (0, 1), (0, 2)]).options(
color='y', cmap={0: 'red', 1: 'green', 2: 'blue'})
Expand Down

0 comments on commit 480a5e5

Please sign in to comment.