Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changhiskhan/datadiff #380

Merged
merged 3 commits into from
Dec 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/lance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def _is_plain_dataset(filesystem: pa.fs.FileSystem, uri: str):
return filesystem.get_file_info(manifest).type == pa.fs.FileType.NotFound


def _get_versioned_dataset(filesystem: pa.fs.FileSystem,
uri: str,
version: Optional[int] = None):
def _get_versioned_dataset(
filesystem: pa.fs.FileSystem, uri: str, version: Optional[int] = None
):
# Read the versioned dataset layout.
has_version = True
if version is None:
Expand Down
46 changes: 41 additions & 5 deletions python/lance/tests/util/test_versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
import pytz

import lance
from lance.util.versioning import get_version_asof, compute_metric
from lance.util.versioning import (
ColumnDiff,
LanceDiff,
RowDiff,
compute_metric,
diff,
get_version_asof,
)


def test_get_version_asof(tmp_path):
Expand Down Expand Up @@ -57,13 +64,42 @@ def _get_test_timestamps(naive):


def test_compute_metric(tmp_path):
    """
    Check that the result of `compute_metric` contains a "version" entry
    """
    base_dir = tmp_path / "test"
    # The fixture dataset is written by the shared helper only; also writing
    # the tables inline would create the dataset a second time at this path.
    _create_dataset(base_dir)

    def func(dataset):
        # Per-dataset metric: column-wise maximum, reshaped to a one-row frame.
        return dataset.to_table().to_pandas().max().to_frame().T

    metrics = compute_metric(base_dir, func)
    assert "version" in metrics


def _create_dataset(base_dir):
    """
    Build the fixture dataset at `base_dir` in three steps: an initial
    write, an append, and a merge on column "a" that brings in the
    columns "c" and "d"
    """
    initial_rows = [{"a": 1, "b": 2}]
    lance.write_dataset(pa.Table.from_pylist(initial_rows), base_dir)

    appended_rows = [{"a": 100, "b": 200}]
    lance.write_dataset(pa.Table.from_pylist(appended_rows), base_dir, mode="append")

    extra_columns = pa.Table.from_pylist([{"a": 100, "c": 100, "d": 200}])
    dataset = lance.dataset(base_dir)
    dataset.merge(extra_columns, left_on="a", right_on="a")


def test_diff(tmp_path):
    """
    Exercise the row and column views of a diff between versions 1 and 2
    """
    dataset_dir = tmp_path / "test"
    _create_dataset(dataset_dir)

    version_diff = diff(dataset_dir, 1, 2)
    assert isinstance(version_diff, LanceDiff)

    # Row-level view, joined on column "a".
    added_rows = version_diff.rows_added(key="a")
    assert isinstance(added_rows, RowDiff)
    assert added_rows.count_rows == 1
    row_sample = added_rows.head()
    assert isinstance(row_sample, pa.Table)
    assert len(row_sample) == 1

    # Column-level view.
    # NOTE(review): versions 1 and 2 presumably correspond to the initial
    # write and the append, which carry the same columns ("a", "b") — confirm
    # that two *added* columns is really the intended expectation here.
    added_cols = version_diff.columns_added()
    assert isinstance(added_cols, ColumnDiff)
    assert len(added_cols.schema) == 2
    col_sample = added_cols.head()
    assert isinstance(col_sample, pa.Table)
    assert len(col_sample) == 2
139 changes: 133 additions & 6 deletions python/lance/util/versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import itertools
from datetime import datetime, timezone
from functools import cached_property
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Callable, Union

import duckdb
import pandas as pd
import pyarrow as pa

from lance.lib import FileSystemDataset

Expand Down Expand Up @@ -63,15 +69,17 @@ def get_version_asof(ds: FileSystemDataset, ts: [datetime, pd.Timestamp, str]) -
raise ValueError(f"{ts} is earlier than the first version of this dataset")


def compute_metric(uri: [Path, str],
metric_func: Callable[[FileSystemDataset], pd.DataFrame],
versions: list = None,
with_version: Union[bool, str] = True) \
-> pd.DataFrame:
def compute_metric(
uri: [Path, str],
metric_func: Callable[[FileSystemDataset], pd.DataFrame],
versions: list = None,
with_version: Union[bool, str] = True,
) -> pd.DataFrame:
"""
Compare metrics across versions of a dataset
"""
import lance

if versions is None:
versions = lance.dataset(uri).versions()
results = []
Expand All @@ -87,3 +95,122 @@ def compute_metric(uri: [Path, str],
vdf[vcol_name] = v
results.append(vdf)
return pd.concat(results)


def diff(uri, v1: int, v2: int) -> LanceDiff:
    """
    Open the dataset at `uri` at versions `v1` and `v2` and wrap the
    pair in a LanceDiff

    Parameters
    ----------
    uri:
        Dataset location, passed through to `lance.dataset`
    v1: int
        Version used as the starting point of the comparison
    v2: int
        Version used as the end point of the comparison
    """
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import with the lance package — confirm).
    import lance

    start = lance.dataset(uri, version=v1)
    end = lance.dataset(uri, version=v2)
    return LanceDiff(start, end)


class LanceDiff:
    """
    Pair of dataset versions exposing what `v2` adds relative to `v1`

    Parameters
    ----------
    v1: FileSystemDataset
        Dataset at the starting version
    v2: FileSystemDataset
        Dataset at the ending version
    """

    def __init__(self, v1: FileSystemDataset, v2: FileSystemDataset):
        self.v1 = v1
        self.v2 = v2

    def __repr__(self):
        # Both figures are computed on demand, so rendering the repr reads
        # the row counts and schemas of the underlying datasets.
        return (
            "LanceDiff\n"
            f" Added: {self.rows_added().count_rows} rows, "
            f"{len(self.columns_added().schema)} columns"
        )

    def rows_added(self, key: str = None) -> RowDiff:
        """
        Get a RowDiff between v1 and v2 that uses `key` as its join key
        """
        return RowDiff(self.v1, self.v2, key)

    def columns_added(self) -> ColumnDiff:
        """
        Get the net new fields between v1 and v2. You can then
        use the `schema` property to see new fields
        and `head(n)` method to get the data in those fields
        """
        # Compare by flattened field *name*: the lookup set holds names, so
        # testing the Field objects themselves against it never matches and
        # every v2 field would be reported as new.
        v1_names = {f.name for f in _flat_schema(self.v1.schema)}
        new_fields = [
            f for f in _flat_schema(self.v2.schema) if f.name not in v1_names
        ]
        return ColumnDiff(self.v2, new_fields)


class RowDiff:
    """
    Row diff between two dataset versions using the specified join keys

    Parameters
    ----------
    ds_start: FileSystemDataset
        Dataset at the starting version
    ds_end: FileSystemDataset
        Dataset at the ending version
    key: str or list[str] or None
        Column name(s) used to match rows between the two versions.
        Required by `head`; `count_rows` does not use it
    """

    def __init__(
        self,
        ds_start: FileSystemDataset,
        ds_end: FileSystemDataset,
        key: Union[str, list[str], None],
    ):
        self.ds_start = ds_start
        self.ds_end = ds_end
        # Normalize a single column name to a one-element list.
        self.key = [key] if isinstance(key, str) else key

    def _query(self, projection: list[str], offset: int = 0, limit: int = 0) -> str:
        """
        Build the SQL that selects the rows of table `v2` without a
        matching row in table `v1`

        Parameters
        ----------
        projection: list[str]
            Select-list entries, inserted verbatim and comma separated
        offset: int, default 0
            Adds an OFFSET clause when positive
        limit: int, default 0
            Adds a LIMIT clause when positive

        Raises
        ------
        ValueError
            If this diff was created without a join key
        """
        if not self.key:
            raise ValueError(
                "RowDiff needs a join key to select rows; "
                "pass `key` when creating the diff"
            )
        # NOTE(review): projection entries and key names are inserted as-is;
        # an unqualified column that exists in both versions is presumably
        # ambiguous in this join — confirm whether callers should pass
        # `v2.<column>`.
        join = " AND ".join(f"v2.{k}=v1.{k}" for k in self.key)
        # Anti-join: keep the v2 rows for which the left join found no v1 row.
        query = (
            f"SELECT {','.join(projection)} FROM v2 "
            f"LEFT JOIN v1 ON {join} "
            f"WHERE v1.{self.key[0]} IS NULL"
        )
        if offset > 0:
            query += f" OFFSET {offset}"
        if limit > 0:
            query += f" LIMIT {limit}"
        return query

    @cached_property
    def count_rows(self) -> int:
        """
        Return the number of rows in this diff

        Computed as the difference between the total row counts of the
        two versions; the join key is not used.
        """
        return self.ds_end.count_rows() - self.ds_start.count_rows()

    def head(self, n: int = 10, columns: list[str] = None) -> pa.Table:
        """
        Retrieve the rows in this diff as a pyarrow table

        Parameters
        ----------
        n: int, default 10
            Get this many rows. No LIMIT is applied when `n` is not positive
        columns: list[str], default None
            Get all columns if not specified

        Raises
        ------
        ValueError
            If this diff was created without a join key
        """
        if columns is None:
            columns = ["*"]
        qry = self._query(columns, limit=n)
        # The SQL refers to tables named `v1` and `v2`. These locals are what
        # duckdb resolves those names to, so they must keep exactly these
        # names even though nothing else in this method reads them.
        v1 = self.ds_start  # noqa: F841
        v2 = self.ds_end  # noqa: F841
        return duckdb.query(qry).to_arrow_table()


class ColumnDiff:
    """
    A dataset together with the subset of its fields that make up a diff
    """

    def __init__(self, ds: FileSystemDataset, fields: list[pa.Field]):
        self.dataset, self.fields = ds, fields

    @property
    def schema(self) -> pa.Schema:
        """
        Schema built from the fields in this diff
        """
        diff_schema = pa.schema(self.fields)
        return diff_schema

    def head(self, n=10, columns=None) -> pa.Table:
        """
        Fetch the first `n` rows of the dataset, restricted to this diff

        Parameters
        ----------
        n: int, default 10
            Number of rows requested from the dataset
        columns: list[str], default None
            Columns to fetch; when None, the names of all fields in this
            diff are used
        """
        selected = [field.name for field in self.fields] if columns is None else columns
        return self.dataset.head(n, columns=selected)


def _flat_schema(schema):
    """
    Call `flatten()` on every field of `schema` and chain the results
    into a single iterator (which can be consumed only once)
    """
    flattened = [field.flatten() for field in schema]
    return itertools.chain.from_iterable(flattened)