diff --git a/python/lance/__init__.py b/python/lance/__init__.py
index aa16af0ac5..132472a4d5 100644
--- a/python/lance/__init__.py
+++ b/python/lance/__init__.py
@@ -95,9 +95,9 @@ def _is_plain_dataset(filesystem: pa.fs.FileSystem, uri: str):
     return filesystem.get_file_info(manifest).type == pa.fs.FileType.NotFound
 
 
-def _get_versioned_dataset(filesystem: pa.fs.FileSystem,
-                           uri: str,
-                           version: Optional[int] = None):
+def _get_versioned_dataset(
+    filesystem: pa.fs.FileSystem, uri: str, version: Optional[int] = None
+):
     # Read the versioned dataset layout.
     has_version = True
     if version is None:
diff --git a/python/lance/tests/util/test_versioning.py b/python/lance/tests/util/test_versioning.py
index ee3267df1c..0cfd5b5bad 100644
--- a/python/lance/tests/util/test_versioning.py
+++ b/python/lance/tests/util/test_versioning.py
@@ -18,7 +18,14 @@ import pytz
 
 
 import lance
-from lance.util.versioning import get_version_asof, compute_metric
+from lance.util.versioning import (
+    ColumnDiff,
+    LanceDiff,
+    RowDiff,
+    compute_metric,
+    diff,
+    get_version_asof,
+)
 
 
 def test_get_version_asof(tmp_path):
@@ -57,13 +64,46 @@ def _get_test_timestamps(naive):
 
 
 def test_compute_metric(tmp_path):
-    table1 = pa.Table.from_pylist([{"a": 1, "b": 2}])
     base_dir = tmp_path / "test"
-    lance.write_dataset(table1, base_dir)
-    table2 = pa.Table.from_pylist([{"a": 100, "b": 200}])
-    lance.write_dataset(table2, base_dir, mode="append")
+    _create_dataset(base_dir)
 
     def func(dataset):
         return dataset.to_table().to_pandas().max().to_frame().T
+
     metrics = compute_metric(base_dir, func)
     assert "version" in metrics
+
+
+def _create_dataset(base_dir):
+    table1 = pa.Table.from_pylist([{"a": 1, "b": 2}])
+    lance.write_dataset(table1, base_dir)
+    table2 = pa.Table.from_pylist([{"a": 100, "b": 200}])
+    lance.write_dataset(table2, base_dir, mode="append")
+    table3 = pa.Table.from_pylist([{"a": 100, "c": 100, "d": 200}])
+    lance.dataset(base_dir).merge(table3, left_on="a", right_on="a")
+
+
+def test_diff(tmp_path):
+    base_dir = tmp_path / "test"
+    _create_dataset(base_dir)
+
+    d = diff(base_dir, 1, 2)
+    assert isinstance(d, LanceDiff)
+
+    rows = d.rows_added(key="a")
+    assert isinstance(rows, RowDiff)
+    assert rows.count_rows == 1
+    tbl = rows.head()
+    assert isinstance(tbl, pa.Table)
+    assert len(tbl) == 1
+
+    # Versions 1 -> 2 only append a row, so no columns are added
+    assert len(d.columns_added().schema) == 0
+
+    # The merge in _create_dataset is expected to land as version 3
+    cols = diff(base_dir, 2, 3).columns_added()
+    assert isinstance(cols, ColumnDiff)
+    assert len(cols.schema) == 2
+    tbl = cols.head()
+    assert isinstance(tbl, pa.Table)
+    assert len(tbl) == 2
diff --git a/python/lance/util/versioning.py b/python/lance/util/versioning.py
index 42783267c7..96cb958767 100644
--- a/python/lance/util/versioning.py
+++ b/python/lance/util/versioning.py
@@ -21,11 +21,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+import itertools
 from datetime import datetime, timezone
+from functools import cached_property
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Callable, Union
 
+import duckdb
 import pandas as pd
+import pyarrow as pa
 
 from lance.lib import FileSystemDataset
 
@@ -63,15 +69,17 @@ def get_version_asof(ds: FileSystemDataset, ts: [datetime, pd.Timestamp, str]) -
     raise ValueError(f"{ts} is earlier than the first version of this dataset")
 
 
-def compute_metric(uri: [Path, str],
-                   metric_func: Callable[[FileSystemDataset], pd.DataFrame],
-                   versions: list = None,
-                   with_version: Union[bool, str] = True) \
-        -> pd.DataFrame:
+def compute_metric(
+    uri: [Path, str],
+    metric_func: Callable[[FileSystemDataset], pd.DataFrame],
+    versions: list = None,
+    with_version: Union[bool, str] = True,
+) -> pd.DataFrame:
     """
     Compare metrics across versions of a dataset
     """
     import lance
+
     if versions is None:
         versions = lance.dataset(uri).versions()
     results = []
@@ -87,3 +95,161 @@ def compute_metric(uri: [Path, str],
             vdf[vcol_name] = v
         results.append(vdf)
     return pd.concat(results)
+
+
+def diff(uri, v1: int, v2: int) -> LanceDiff:
+    """
+    Compare two versions of the dataset at `uri`
+
+    Parameters
+    ----------
+    uri: str or Path
+        Location of the dataset
+    v1: int
+        Version to use as the start of the diff
+    v2: int
+        Version to use as the end of the diff
+    """
+    import lance
+
+    return LanceDiff(lance.dataset(uri, version=v1), lance.dataset(uri, version=v2))
+
+
+class LanceDiff:
+    """
+    Rows and columns added between two versions of a dataset
+    """
+
+    def __init__(self, v1: FileSystemDataset, v2: FileSystemDataset):
+        self.v1 = v1
+        self.v2 = v2
+
+    def __repr__(self):
+        return (
+            "LanceDiff\n"
+            f" Added: {self.rows_added().count_rows} rows, "
+            f"{len(self.columns_added().schema)} columns"
+        )
+
+    def rows_added(self, key: Union[str, list[str]] = None) -> RowDiff:
+        """
+        Get the rows of v2 that have no match in v1 on `key`. The key
+        is only needed by `RowDiff.head`, not by `RowDiff.count_rows`
+        """
+        return RowDiff(self.v1, self.v2, key)
+
+    def columns_added(self) -> ColumnDiff:
+        """
+        Get the net new fields between v1 and v2. You can then
+        use the `schema` property to see new fields
+        and `head(n)` method to get the data in those fields
+        """
+        # Compare by name; a Field object is never a member of a set of names
+        v1_names = {f.name for f in _flat_schema(self.v1.schema)}
+        new_fields = [
+            f for f in _flat_schema(self.v2.schema) if f.name not in v1_names
+        ]
+        return ColumnDiff(self.v2, new_fields)
+
+
+class RowDiff:
+    """
+    Row diff between two dataset versions using the specified join keys
+    """
+
+    def __init__(
+        self,
+        ds_start: FileSystemDataset,
+        ds_end: FileSystemDataset,
+        key: Union[str, list[str]],
+    ):
+        self.ds_start = ds_start
+        self.ds_end = ds_end
+        self.key = [key] if isinstance(key, str) else key
+
+    def _query(self, projection: list[str], offset: int = 0, limit: int = 0) -> str:
+        """
+        Build the SQL selecting the rows of v2 without a match in v1
+        """
+        if not self.key:
+            raise ValueError("A join key is required to retrieve the rows of a diff")
+        join = " AND ".join(f"v2.{k}=v1.{k}" for k in self.key)
+        # Anti-join: unmatched v2 rows come back with NULL on the v1 side
+        query = (
+            f"SELECT {','.join(projection)} FROM v2 "
+            f"LEFT JOIN v1 ON {join} "
+            f"WHERE v1.{self.key[0]} IS NULL"
+        )
+        if offset > 0:
+            query += f" OFFSET {offset}"
+        if limit > 0:
+            query += f" LIMIT {limit}"
+        return query
+
+    @cached_property
+    def count_rows(self) -> int:
+        """
+        Return the number of rows in this diff. This is the difference
+        in total row counts, which assumes rows were only appended
+        """
+        return self.ds_end.count_rows() - self.ds_start.count_rows()
+
+    def head(self, n: int = 10, columns: list[str] = None) -> pa.Table:
+        """
+        Retrieve the rows in this diff as a pyarrow table
+
+        Parameters
+        ----------
+        n: int, default 10
+            Get this many rows
+        columns: list[str], default None
+            Get all columns of the end version if not specified
+        """
+        # Not unused: duckdb resolves the `v1` / `v2` tables named in the
+        # query by looking up these local variables
+        v1 = self.ds_start  # noqa: F841
+        v2 = self.ds_end  # noqa: F841
+        if columns is None:
+            # A bare * would also return the all-NULL columns of v1
+            columns = ["v2.*"]
+        qry = self._query(columns, limit=n)
+        return duckdb.query(qry).to_arrow_table()
+
+
+class ColumnDiff:
+    """
+    Fields of a dataset version that are not present in an earlier version
+    """
+
+    def __init__(self, ds: FileSystemDataset, fields: list[pa.Field]):
+        self.dataset = ds
+        self.fields = fields
+
+    @property
+    def schema(self) -> pa.Schema:
+        """
+        Flattened schema containing fields for this diff
+        """
+        return pa.schema(self.fields)
+
+    def head(self, n=10, columns=None) -> pa.Table:
+        """
+        Return the first `n` rows for fields in this diff as a pyarrow Table
+
+        Parameters
+        ----------
+        n: int, default 10
+            How many rows to return
+        columns: list[str], default None
+            If None then all fields are returned
+        """
+        if columns is None:
+            columns = [f.name for f in self.fields]
+        return self.dataset.head(n, columns=columns)
+
+
+def _flat_schema(schema):
+    """
+    Iterate over the fields in `schema` with nested fields flattened
+    """
+    return itertools.chain.from_iterable(f.flatten() for f in schema)