Skip to content

Commit

Permalink
fix: Support pandas ExtensionArray ordering (#6481)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Jan 17, 2025
1 parent e1f584b commit becb5f3
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 9 deletions.
2 changes: 2 additions & 0 deletions holoviews/core/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ class Dataset(Element, metaclass=PipelineMeta):
_vdim_reductions = {}
_kdim_reductions = {}

interface: Interface

def __new__(cls, data=None, kdims=None, vdims=None, **kwargs):
"""
Allows casting a DynamicMap to an Element class like hv.Curve, by applying the
Expand Down
4 changes: 3 additions & 1 deletion holoviews/core/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
axis or map dimension. Also supplies the Dimensioned abstract
baseclass for classes that accept Dimension values.
"""
from __future__ import annotations

import builtins
import datetime as dt
import re
Expand Down Expand Up @@ -922,7 +924,7 @@ def dimensions(self, selection='all', label=False):
if label else dim for dim in dims]


def get_dimension(self, dimension, default=None, strict=False):
def get_dimension(self, dimension, default=None, strict=False) -> Dimension | None:
"""Get a Dimension object by name or index.
Args:
Expand Down
30 changes: 22 additions & 8 deletions holoviews/element/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
Expand All @@ -19,6 +22,11 @@
sort_topologically,
)

if TYPE_CHECKING:
from typing import TypeVar

Array = TypeVar("Array", np.ndarray, pd.api.extensions.ExtensionArray)


def split_path(path):
"""
Expand Down Expand Up @@ -126,18 +134,19 @@ class categorical_aggregate2d(Operation):
The grid interface types to use when constructing the gridded Dataset.""")

@classmethod
def _get_coords(cls, obj):
def _get_coords(cls, obj: Dataset):
"""
Get the coordinates of the 2D aggregate, maintaining the correct
sorting order.
"""
xdim, ydim = obj.dimensions(label=True)[:2]
xcoords = obj.dimension_values(xdim, False)
ycoords = obj.dimension_values(ydim, False)

if xcoords.dtype.kind not in 'SUO':
xcoords = np.sort(xcoords)
xcoords = sort_arr(xcoords)
if ycoords.dtype.kind not in 'SUO':
return xcoords, np.sort(ycoords)
return xcoords, sort_arr(ycoords)

# Determine global orderings of y-values using topological sort
grouped = obj.groupby(xdim, container_type=dict,
Expand All @@ -149,19 +158,18 @@ def _get_coords(cls, obj):
if len(vals) == 1:
orderings[vals[0]] = [vals[0]]
else:
for i in range(len(vals)-1):
p1, p2 = vals[i:i+2]
for p1, p2 in zip(vals[:-1], vals[1:]):
orderings[p1] = [p2]
if sort:
if vals.dtype.kind in ('i', 'f'):
sort = (np.diff(vals)>=0).all()
else:
sort = np.array_equal(np.sort(vals), vals)
sort = np.array_equal(sort_arr(vals), vals)
if sort or one_to_one(orderings, ycoords):
ycoords = np.sort(ycoords)
ycoords = sort_arr(ycoords)
elif not is_cyclic(orderings):
coords = list(itertools.chain(*sort_topologically(orderings)))
ycoords = coords if len(coords) == len(ycoords) else np.sort(ycoords)
ycoords = coords if len(coords) == len(ycoords) else sort_arr(ycoords)
return np.asarray(xcoords), np.asarray(ycoords)

def _aggregate_dataset(self, obj):
Expand Down Expand Up @@ -332,3 +340,9 @@ def connect_edges(graph):
end = end_ds.array(end_ds.kdims[:2])
paths.append(np.array([start[0], end[0]]))
return paths


def sort_arr(arr: Array) -> Array:
    """Return a sorted version of ``arr``.

    pandas ``ExtensionArray`` inputs are reordered through their own
    ``argsort`` and then indexed with the resulting permutation, rather
    than being handed to ``np.sort``. Every other input is passed to
    ``np.sort``.

    Args:
        arr: A NumPy array or a pandas ExtensionArray.

    Returns:
        The values of ``arr`` in sorted order.
    """
    if not isinstance(arr, pd.api.extensions.ExtensionArray):
        return np.sort(arr)
    # Let the extension array define the ordering (e.g. the category
    # order of an ordered Categorical), then apply that permutation.
    order = arr.argsort()
    return arr[order]
21 changes: 21 additions & 0 deletions holoviews/tests/plotting/bokeh/test_barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,27 @@ def test_bars_not_continuous_data_list_custom_width(self):
plot = bokeh_renderer.get_plot(bars)
assert plot.handles["glyph"].width == 1

def test_bars_categorical_order(self):
    """Check that the x-range factors follow the ordered categorical.

    The category labels are declared in the order "~1M", "~10M", "~100M",
    which is the reverse of their plain string sort order, so the
    assertion only holds if the declared category order is used.
    """
    size_labels = pd.array(["~1M", "~10M", "~100M"], dtype="string")
    cells_dtype = pd.CategoricalDtype(size_labels, ordered=True)
    df = pd.DataFrame({
        "cells": cells_dtype.categories.astype(cells_dtype),
        "time": pd.array([2.99, 18.5, 835.2]),
        "function": pd.array(["read", "read", "read"]),
    })

    plot = bokeh_renderer.get_plot(Bars(df, ["function", "cells"], ["time"]))
    expected = [("read", label) for label in ("~1M", "~10M", "~100M")]

    np.testing.assert_equal(plot.handles["x_range"].factors, expected)

def test_bars_group(self):
samples = 100

Expand Down

0 comments on commit becb5f3

Please sign in to comment.