-
-
Notifications
You must be signed in to change notification settings - Fork 404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make HeatMap more general #849
Changes from 11 commits
bec8024
339f988
1d3d57e
69a9793
efd4bd9
17651f0
f3543e6
843387c
29f47c9
3f4b073
d68485f
143c301
0a91dce
03cebf6
844c1ad
fb4b207
fcac23e
d380d08
dcae11f
9082070
f5998f2
050c4c7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ | |
from ..core.util import pd | ||
from .chart import Curve | ||
from .tabular import Table | ||
from .util import compute_edges, toarray | ||
from .util import compute_edges, toarray, get_2d_aggregate | ||
|
||
try: | ||
from ..core.data import PandasInterface | ||
|
@@ -365,16 +365,14 @@ def dimension_values(self, dimension, expanded=True, flat=True): | |
return super(QuadMesh, self).dimension_values(idx) | ||
|
||
|
||
|
||
class HeatMap(Dataset, Element2D): | ||
""" | ||
HeatMap is an atomic Element used to visualize two dimensional | ||
parameter spaces. It supports sparse or non-linear spaces, dynamically | ||
upsampling them to a dense representation, which can be visualized. | ||
|
||
A HeatMap can be initialized with any dict or NdMapping type with | ||
two-dimensional keys. Once instantiated the dense representation is | ||
available via the .data property. | ||
two-dimensional keys. | ||
""" | ||
|
||
group = param.String(default='HeatMap', constant=True) | ||
|
@@ -383,85 +381,18 @@ class HeatMap(Dataset, Element2D): | |
|
||
vdims = param.List(default=[Dimension('z')]) | ||
|
||
def __init__(self, data, extents=None, **params): | ||
depth = 1 | ||
|
||
def __init__(self, data, **params): | ||
super(HeatMap, self).__init__(data, **params) | ||
data, self.raster = self._compute_raster() | ||
self.data = data.data | ||
self.interface = data.interface | ||
self.depth = 1 | ||
if extents is None: | ||
(d1, d2) = self.raster.shape[:2] | ||
self.extents = (0, 0, d2, d1) | ||
else: | ||
self.extents = extents | ||
|
||
|
||
def _compute_raster(self): | ||
if self.interface.gridded: | ||
return self, np.flipud(self.dimension_values(2, flat=False)) | ||
d1keys = self.dimension_values(0, False) | ||
d2keys = self.dimension_values(1, False) | ||
coords = [(d1, d2, np.NaN) for d1 in d1keys for d2 in d2keys] | ||
dtype = 'dataframe' if pd else 'dictionary' | ||
dense_data = Dataset(coords, kdims=self.kdims, vdims=self.vdims, datatype=[dtype]) | ||
concat_data = self.interface.concatenate([dense_data, Dataset(self)], datatype=dtype) | ||
with warnings.catch_warnings(): | ||
warnings.filterwarnings('ignore', r'Mean of empty slice') | ||
data = concat_data.aggregate(self.kdims, np.nanmean) | ||
array = data.dimension_values(2).reshape(len(d1keys), len(d2keys)) | ||
return data, np.flipud(array.T) | ||
|
||
|
||
def __setstate__(self, state): | ||
if '_data' in state: | ||
data = state['_data'] | ||
if isinstance(data, NdMapping): | ||
items = [tuple(k)+((v,) if np.isscalar(v) else tuple(v)) | ||
for k, v in data.items()] | ||
kdims = state['kdims'] if 'kdims' in state else self.kdims | ||
vdims = state['vdims'] if 'vdims' in state else self.vdims | ||
data = Dataset(items, kdims=kdims, vdims=vdims).data | ||
elif isinstance(data, Dataset): | ||
data = data.data | ||
kdims = data.kdims | ||
vdims = data.vdims | ||
state['data'] = data | ||
state['kdims'] = kdims | ||
state['vdims'] = vdims | ||
self.__dict__ = state | ||
|
||
if isinstance(self.data, NdElement): | ||
self.interface = NdElementInterface | ||
elif isinstance(self.data, np.ndarray): | ||
self.interface = ArrayInterface | ||
elif util.is_dataframe(self.data): | ||
self.interface = PandasInterface | ||
elif isinstance(self.data, dict): | ||
self.interface = DictInterface | ||
self.depth = 1 | ||
data, self.raster = self._compute_raster() | ||
self.interface = data.interface | ||
self.data = data.data | ||
if 'extents' not in state: | ||
(d1, d2) = self.raster.shape[:2] | ||
self.extents = (0, 0, d2, d1) | ||
|
||
super(HeatMap, self).__setstate__(state) | ||
|
||
def dense_keys(self): | ||
d1keys = self.dimension_values(0, False) | ||
d2keys = self.dimension_values(1, False) | ||
return list(zip(*[(d1, d2) for d1 in d1keys for d2 in d2keys])) | ||
|
||
|
||
def dframe(self, dense=False): | ||
if dense: | ||
keys1, keys2 = self.dense_keys() | ||
dense_map = self.clone({(k1, k2): self._data.get((k1, k2), np.NaN) | ||
for k1, k2 in product(keys1, keys2)}) | ||
return dense_map.dframe() | ||
return super(HeatMap, self).dframe() | ||
self.gridded = get_2d_aggregate(self) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice to see how much … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That said, it isn't immediately obvious that … |
||
|
||
@property | ||
def raster(self): | ||
self.warning("The .raster attribute on HeatMap is deprecated, " | ||
"the 2D aggregate is now computed dynamically " | ||
"during plotting.") | ||
return self.gridded.dimension_values(2, flat=False) | ||
|
||
|
||
class Image(SheetCoordinateSystem, Raster): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,19 @@ | ||
import numpy as np | ||
|
||
from ..core import Dataset, OrderedDict | ||
from ..core.util import pd, is_nan | ||
|
||
try: | ||
import dask | ||
except: | ||
dask = None | ||
|
||
try: | ||
import xarray as xr | ||
except: | ||
xr = None | ||
|
||
|
||
def toarray(v, index_value=False): | ||
""" | ||
Interface helper function to turn dask Arrays into numpy arrays as | ||
|
@@ -30,3 +39,60 @@ def compute_edges(edges): | |
raise ValueError('Centered bins have to be of equal width.') | ||
edges -= width/2. | ||
return np.concatenate([edges, [edges[-1]+width]]) | ||
|
||
|
||
def reduce_fn(x):
    """
    Aggregation function to get the first non-NaN value.

    If ``pd`` is truthy and ``x`` is a ``pd.Series`` its underlying
    ``.values`` are scanned; otherwise ``x`` itself is iterated. The
    first element for which ``is_nan`` is false is returned. If every
    element is NaN, or ``x`` is empty, NaN is returned.
    """
    values = x.values if pd and isinstance(x, pd.Series) else x
    for v in values:
        if not is_nan(v):
            return v
    # Use np.nan: the np.NaN alias was removed in NumPy 2.0.
    return np.nan
|
||
|
||
def get_2d_aggregate(obj): | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Perhaps this would be better expressed as an operation? Then maybe it could have a minimal docstring example in the class docstring? |
||
Generates a categorical 2D aggregate by inserting NaNs at all | ||
cross-product locations that do not already have a value assigned. | ||
Returns a 2D gridded Dataset object. | ||
""" | ||
if obj.interface.gridded: | ||
return obj | ||
elif obj.ndims > 2: | ||
raise Exception("Cannot aggregate more than two dimensions") | ||
|
||
dims = obj.dimensions(label=True) | ||
xdim, ydim = dims[:2] | ||
nvdims = len(dims) - 2 | ||
d1keys = obj.dimension_values(xdim, False) | ||
d2keys = obj.dimension_values(ydim, False) | ||
|
||
is_sorted = np.array_equal(np.sort(d1keys), d1keys) | ||
if is_sorted: | ||
grouped = obj.groupby(xdim, container_type=OrderedDict, | ||
group_type=Dataset).values() | ||
for group in grouped: | ||
d2vals = group.dimension_values(ydim) | ||
is_sorted &= np.array_equal(d2vals, np.sort(d2vals)) | ||
|
||
if is_sorted: | ||
d1keys, d2keys = np.sort(d1keys), np.sort(d2keys) | ||
coords = [(d1, d2) + (np.NaN,)*nvdims for d2 in d2keys for d1 in d1keys] | ||
|
||
dtype = 'dataframe' if pd else 'dictionary' | ||
dense_data = Dataset(coords, kdims=obj.kdims, vdims=obj.vdims, datatype=[dtype]) | ||
concat_data = obj.interface.concatenate([dense_data, Dataset(obj)], datatype=dtype) | ||
agg = concat_data.reindex([xdim, ydim]).aggregate([xdim, ydim], reduce_fn) | ||
shape = (len(d2keys), len(d1keys)) | ||
grid_data = {xdim: d1keys, ydim: d2keys} | ||
|
||
for vdim in dims[2:]: | ||
data = agg.dimension_values(vdim).reshape(shape) | ||
data = np.ma.array(data, mask=np.logical_not(np.isfinite(data))) | ||
grid_data[vdim] = data | ||
|
||
grid_type = 'xarray' if xr else 'grid' | ||
return agg.clone(grid_data, datatype=[grid_type]) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,13 @@ | ||
import numpy as np | ||
import param | ||
|
||
from ...core.util import cartesian_product | ||
from bokeh.models.mappers import LinearColorMapper | ||
try: | ||
from bokeh.models.mappers import LogColorMapper | ||
except ImportError: | ||
LogColorMapper = None | ||
|
||
from ...core.util import cartesian_product, is_nan, unique_array | ||
from ...element import Image, Raster, RGB | ||
from ..renderer import SkipRendering | ||
from ..util import map_colors | ||
|
@@ -130,26 +136,31 @@ class HeatmapPlot(ColorbarPlot): | |
def _axes_props(self, plots, subplots, element, ranges): | ||
dims = element.dimensions() | ||
labels = self._get_axis_labels(dims) | ||
xvals, yvals = [element.dimension_values(i, False) | ||
agg = element.gridded | ||
xvals, yvals = [unique_array(agg.dimension_values(i, False)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I thought gridded … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, good point, there is no longer any need for the unique_array here. |
||
for i in range(2)] | ||
if self.invert_yaxis: yvals = yvals[::-1] | ||
plot_ranges = {'x_range': [str(x) for x in xvals], | ||
'y_range': [str(y) for y in yvals]} | ||
return ('auto', 'auto'), labels, plot_ranges | ||
|
||
|
||
def get_data(self, element, ranges=None, empty=False): | ||
x, y, z = element.dimensions(label=True) | ||
x, y, z = element.dimensions(label=True)[:3] | ||
aggregate = element.gridded | ||
style = self.style[self.cyclic_index] | ||
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style) | ||
if empty: | ||
data = {x: [], y: [], z: [], 'color': []} | ||
data = {x: [], y: [], z: []} | ||
else: | ||
zvals = np.rot90(element.raster, 3).flatten() | ||
xvals, yvals = [[str(v) for v in element.dimension_values(i)] | ||
zvals = aggregate.dimension_values(z) | ||
xvals, yvals = [[str(v) for v in aggregate.dimension_values(i)] | ||
for i in range(2)] | ||
data = {x: xvals, y: yvals, z: zvals} | ||
|
||
if 'hover' in self.tools+self.default_tools: | ||
for vdim in element.vdims[1:]: | ||
data[vdim.name] = ['' if is_nan(v) else v | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Wondering if an empty string really suggests NaN. 'NaN' would be explicit but might look noisy. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point, I'm now using masked arrays to represent the data; in matplotlib the NaNs are therefore represented by … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I think … |
||
for v in aggregate.dimension_values(vdim)] | ||
return (data, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}, | ||
'height': 1, 'width': 1}) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
|
||
from ...core import CompositeOverlay, Element | ||
from ...core import traversal | ||
from ...core.util import match_spec, max_range, unique_iterator | ||
from ...core.util import match_spec, max_range, unique_iterator, unique_array | ||
from ...element.raster import Image, Raster, RGB | ||
from .element import ColorbarPlot, OverlayPlot | ||
from .plot import MPLPlot, GridPlot | ||
|
@@ -105,20 +105,19 @@ def _annotate_plot(self, ax, annotations): | |
handles = {} | ||
for plot_coord, text in annotations.items(): | ||
handles[plot_coord] = ax.annotate(text, xy=plot_coord, | ||
xycoords='axes fraction', | ||
xycoords='data', | ||
horizontalalignment='center', | ||
verticalalignment='center') | ||
return handles | ||
|
||
|
||
def _annotate_values(self, element): | ||
val_dim = element.vdims[0] | ||
vals = np.rot90(element.raster, 3).flatten() | ||
vals = element.dimension_values(2) | ||
d1uniq, d2uniq = [element.dimension_values(i, False) for i in range(2)] | ||
num_x, num_y = len(d1uniq), len(d2uniq) | ||
xstep, ystep = 1.0/num_x, 1.0/num_y | ||
xpos = np.linspace(xstep/2., 1.0-xstep/2., num_x) | ||
ypos = np.linspace(ystep/2., 1.0-ystep/2., num_y) | ||
xpos = np.linspace(0.5, num_x-0.5, num_x) | ||
ypos = np.linspace(0.5, num_y-0.5, num_y) | ||
plot_coords = product(xpos, ypos) | ||
annotations = {} | ||
for plot_coord, v in zip(plot_coords, vals): | ||
|
@@ -130,21 +129,19 @@ def _annotate_values(self, element): | |
|
||
def _compute_ticks(self, element, ranges): | ||
xdim, ydim = element.kdims | ||
dim1_keys, dim2_keys = [element.dimension_values(i, False) | ||
agg = element.gridded | ||
dim1_keys, dim2_keys = [unique_array(agg.dimension_values(i, False)) | ||
for i in range(2)] | ||
num_x, num_y = len(dim1_keys), len(dim2_keys) | ||
x0, y0, x1, y1 = element.extents | ||
xstep, ystep = ((x1-x0)/num_x, (y1-y0)/num_y) | ||
xpos = np.linspace(x0+xstep/2., x1-xstep/2., num_x) | ||
ypos = np.linspace(y0+ystep/2., y1-ystep/2., num_y) | ||
xpos = np.linspace(.5, num_x-0.5, num_x) | ||
ypos = np.linspace(.5, num_y-0.5, num_y) | ||
xlabels = [xdim.pprint_value(k) for k in dim1_keys] | ||
ylabels = [ydim.pprint_value(k) for k in dim2_keys] | ||
return list(zip(xpos, xlabels)), list(zip(ypos, ylabels)) | ||
|
||
|
||
def init_artists(self, ax, plot_args, plot_kwargs): | ||
l, r, b, t = plot_kwargs['extent'] | ||
ax.set_aspect(float(r - l)/(t-b)) | ||
ax.set_aspect(plot_kwargs.pop('aspect', 1)) | ||
|
||
handles = {} | ||
annotations = plot_kwargs.pop('annotations', None) | ||
|
@@ -156,18 +153,25 @@ def init_artists(self, ax, plot_args, plot_kwargs): | |
|
||
def get_data(self, element, ranges, style): | ||
_, style, axis_kwargs = super(HeatMapPlot, self).get_data(element, ranges, style) | ||
mask = np.logical_not(np.isfinite(element.raster)) | ||
data = np.ma.array(element.raster, mask=mask) | ||
style['annotations'] = self._annotate_values(element) | ||
aggregate = element.gridded | ||
data = np.flipud(aggregate.dimension_values(2, flat=False)) | ||
shape = data.shape | ||
cmap_name = style.pop('cmap', None) | ||
cmap = copy.copy(plt.cm.get_cmap('gray' if cmap_name is None else cmap_name)) | ||
cmap.set_bad('w', 1.) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Might want to make this a plot option at some point instead of hard coding 'w'. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Again, good point; indeed we already expose this via … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I also find it curious that you are using … |
||
style['cmap'] = cmap | ||
style['aspect'] = shape[0]/shape[1] | ||
style['extent'] = (0, shape[0], 0, shape[1]) | ||
style['annotations'] = self._annotate_values(aggregate) | ||
return [data], style, axis_kwargs | ||
|
||
|
||
def update_handles(self, key, axis, element, ranges, style): | ||
im = self.handles['artist'] | ||
data, style, axis_kwargs = self.get_data(element, ranges, style) | ||
l, r, b, t = style['extent'] | ||
im.set_data(data[0]) | ||
im.set_extent((l, r, b, t)) | ||
shape = data[0].shape | ||
im.set_extent((0, shape[1], 0, shape[0])) | ||
im.set_clim((style['vmin'], style['vmax'])) | ||
if 'norm' in style: | ||
im.norm = style['norm'] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might have forgotten... what is this `depth` class attribute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this may be wrong now, will have to look into it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wasn't needed at all in the end, removed it.