Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enh: Support polars in pn.cache #7472

Merged
merged 6 commits on
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 60 additions & 8 deletions panel/io/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def clear(self, func_hashes: list[str | None]=[None]) -> None:
bytes, str, float, int, bool, bytearray, type(None)
)

_NP_SIZE_LARGE = 100_000
_ARRAY_SIZE_LARGE = 100_000

_NP_SAMPLE_SIZE = 100_000
_ARRAY_SAMPLE_SIZE = 100_000

_PANDAS_ROWS_LARGE = 100_000
_DATAFRAME_ROWS_LARGE = 100_000

_PANDAS_SAMPLE_SIZE = 100_000
_DATAFRAME_SAMPLE_SIZE = 100_000

if sys.platform == 'win32':
_TIME_FN = time.perf_counter
Expand Down Expand Up @@ -125,8 +125,8 @@ def _pandas_hash(obj):
if not isinstance(obj, (pd.Series, pd.DataFrame)):
obj = pd.Series(obj)

if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
if len(obj) >= _DATAFRAME_ROWS_LARGE:
obj = obj.sample(n=_DATAFRAME_SAMPLE_SIZE, random_state=0)
try:
if isinstance(obj, pd.DataFrame):
return ((b"%s" % pd.util.hash_pandas_object(obj).sum())
Expand All @@ -138,13 +138,62 @@ def _pandas_hash(obj):
# it contains unhashable objects.
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)

def _polars_combine_hash_expr(columns):
    """
    Build a single Polars expression that combines per-column hashes.

    Mirrors the mixing scheme of ``pd.core.util.hashing.combine_hash_arrays``,
    translated into lazy Polars expressions over UInt64 values so the whole
    combination runs inside the Polars engine.
    """
    import polars as pl

    def u64(value):
        # All arithmetic must happen in unsigned 64-bit space, as in pandas.
        return pl.lit(value, dtype=pl.UInt64)

    multiplier = u64(1000003)
    combined = u64(0x345678)
    n_cols = len(columns)
    for position, name in enumerate(columns):
        column_hash = pl.col(name).hash(seed=0)
        remaining = u64(n_cols - position)
        combined = (combined ^ column_hash) * multiplier
        multiplier = multiplier + (u64(82520) + remaining + remaining)

    return combined + u64(97531)

def _polars_hash(obj):
    """
    Hash a Polars ``Series``, ``DataFrame`` or ``LazyFrame`` into bytes.

    The result combines the container's type name, a hash of its data
    (sampled for large frames) and a hash of its column names, so frames
    that differ in any of these hash differently.
    """
    import polars as pl

    # Capture the concrete type first so e.g. a DataFrame and a LazyFrame
    # with identical contents still produce distinct hashes.
    hash_type = type(obj).__name__.encode()

    if isinstance(obj, pl.Series):
        obj = obj.to_frame()

    columns = obj.collect_schema().names()
    hash_columns = _container_hash(columns)

    # LazyFrame does not support len and sample; for lazy frames the row
    # count is obtained via a (cheap) aggregation query instead.
    if hash_type != b"LazyFrame" and len(obj) >= _DATAFRAME_ROWS_LARGE:
        obj = obj.sample(n=_DATAFRAME_SAMPLE_SIZE, seed=0)
    elif hash_type == b"LazyFrame":
        count = obj.select(pl.col(columns[0]).count()).collect().item()
        if count >= _DATAFRAME_ROWS_LARGE:
            obj = obj.select(pl.all().sample(n=_DATAFRAME_SAMPLE_SIZE, seed=0))

    # Combine the per-column hashes into one UInt64 per row, then sum.
    hash_expr = _polars_combine_hash_expr(columns)
    hash_data = obj.select(hash_expr).sum()
    if hash_type == b"LazyFrame":
        # Only now materialize the lazy query to a single scalar.
        hash_data = hash_data.collect()
    hash_data = _int_to_bytes(hash_data.item())

    return hash_type + hash_data + hash_columns

def _numpy_hash(obj):
    """
    Hash a NumPy array into an MD5 digest of its shape and (sampled) data.

    Large arrays (>= ``_ARRAY_SIZE_LARGE`` elements) are hashed from a
    fixed-seed random sample of ``_ARRAY_SAMPLE_SIZE`` elements so hashing
    stays fast while remaining deterministic.
    """
    # NOTE(review): the scraped diff contained both the pre- and post-rename
    # constant lines; this resolves to the post-change (_ARRAY_*) version.
    h = hashlib.new("md5")
    h.update(_generate_hash(obj.shape))
    if obj.size >= _ARRAY_SIZE_LARGE:
        import numpy as np
        state = np.random.RandomState(0)
        obj = state.choice(obj.flat, size=_ARRAY_SAMPLE_SIZE)
    # Hash the raw bytes of either the full array or the sample —
    # this must run unconditionally, not only for large arrays.
    h.update(obj.tobytes())
    return h.digest()

Expand Down Expand Up @@ -180,6 +229,9 @@ def _io_hash(obj):
'builtins.dict_items' : lambda obj: _container_hash(dict(obj)),
'builtins.getset_descriptor' : lambda obj: obj.__qualname__.encode(),
"numpy.ufunc" : lambda obj: obj.__name__.encode(),
"polars.series.series.Series": _polars_hash,
"polars.dataframe.frame.DataFrame": _polars_hash,
"polars.lazyframe.frame.LazyFrame": _polars_hash,
# Functions
inspect.isbuiltin : lambda obj: obj.__name__.encode(),
inspect.ismodule : lambda obj: obj.__name__,
Expand Down
28 changes: 28 additions & 0 deletions panel/tests/io/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,34 @@ def test_series_hash():
series2.iloc[0] = 3.14
assert not hashes_equal(series1, series2)

def test_polars_dataframe_hash():
    """Equal-content frames hash equal; reordering a column changes the hash."""
    pl = pytest.importorskip("polars")
    data = {
        "A": [0.0, 1.0, 2.0, 3.0, 4.0],
        "B": [0.0, 1.0, 0.0, 1.0, 0.0],
        "C": ["foo1", "foo2", "foo3", "foo4", "foo5"],
    }
    # Exercise both the eager DataFrame and the lazy LazyFrame paths.
    for frame_type in (pl.DataFrame, pl.LazyFrame):
        df1, df2 = frame_type(data), frame_type(data)
        assert hashes_equal(df1, df2)
        df2 = df2.with_columns(A=pl.col("A").sort(descending=True))
        assert not hashes_equal(df1, df2)

def test_polars_series_hash():
    """Identical Series hash equal; replacing a value changes the hash."""
    pl = pytest.importorskip("polars")
    values = [0.0, 1.0, 2.0, 3.0, 4.0]
    ser1 = pl.Series(values)
    ser2 = pl.Series(values)

    assert hashes_equal(ser1, ser2)
    ser2 = ser2.replace(0.0, 3.14)
    assert not hashes_equal(ser1, ser2)

def test_ufunc_hash():
    """NumPy ufuncs hash by identity of name: same matches, different differ."""
    same = hashes_equal(np.absolute, np.absolute)
    different = hashes_equal(np.sin, np.cos)
    assert same
    assert not different
Expand Down
1 change: 1 addition & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ ipympl = "*"
ipyvuetify = "*"
ipywidgets_bokeh = "*"
numba = "*"
polars = "*"
reacton = "*"
scipy = "*"
textual = "*"
Expand Down
Loading