diff --git a/CHANGELOG.md b/CHANGELOG.md
index 89c95f856..3d57f0e3a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@
### Added
+- Refactoring KNN Shapley values with the new sampler architecture [PR #610](https://github.com/aai-institute/pyDVL/pull/610).
- Refactoring MSR Banzhaf semivalues with the new sampler architecture.
[PR #605](https://github.com/aai-institute/pyDVL/pull/605)
- Refactoring group-testing shapley values with new sampler architecture
@@ -36,7 +37,7 @@
[PR #597](https://github.com/aai-institute/pyDVL/pull/597)
- Fix a bug in the calculation of variance estimates for MSR Banzhaf
[PR #605](https://github.com/aai-institute/pyDVL/pull/605)
-
+- Fix a bug in KNN Shapley values. See [Issue 613](https://github.com/aai-institute/pyDVL/issues/613) for details.
### Changed
diff --git a/src/pydvl/valuation/methods/knn_shapley.py b/src/pydvl/valuation/methods/knn_shapley.py
index 89ffefdd7..f2b26f3c9 100644
--- a/src/pydvl/valuation/methods/knn_shapley.py
+++ b/src/pydvl/valuation/methods/knn_shapley.py
@@ -13,3 +13,144 @@
the VLDB Endowment, Vol. 12, No. 11, pp. 1610–1623.
"""
+from __future__ import annotations
+
+import numpy as np
+from joblib import Parallel, delayed
+from numpy.typing import NDArray
+from sklearn.neighbors import NearestNeighbors
+from tqdm.auto import tqdm
+from typing_extensions import Self
+
+from pydvl.utils.status import Status
+from pydvl.valuation.base import Valuation
+from pydvl.valuation.dataset import Dataset
+from pydvl.valuation.result import ValuationResult
+from pydvl.valuation.utility import KNNClassifierUtility
+
+
+class KNNShapleyValuation(Valuation):
+ """Computes exact Shapley values for a KNN classifier.
+
+ This implements the method described in
+ (Jia, R. et al., 2019)1.
+ It exploits the local structure of K-Nearest Neighbours to reduce the number
+ of calls to the utility function to a constant number per index, thus
+ reducing computation time to $O(n)$.
+
+ Args:
+ utility: KNNUtility with a KNN model to extract parameters from. The object
+ will not be modified nor used other than to call [get_params()](
+ )
+ progress: Whether to display a progress bar.
+
+ """
+
+ def __init__(self, utility: KNNClassifierUtility, progress: bool = True):
+ super().__init__()
+ self.utility = utility
+ self.progress = progress
+
+ config = self.utility.model.get_params()
+ self.n_neighbors = config["n_neighbors"]
+
+ del config["weights"]
+ self.helper_model = NearestNeighbors(**config)
+
+ def fit(self, data: Dataset) -> Self:
+ """Calculate exact shapley values for a KNN model on a dataset.
+
+ This fit method bypasses direct evaluations of the utility function and
+ calculates the Shapley values directly.
+
+ In contrast to other data valuation models, the runtime increases linearly
+ with the size of the test dataset.
+
+ Calculating the KNN valuation is a computationally expensive task that
+ can be parallelized. To do so, call the `fit()` method inside a
+ `joblib.parallel_config` context manager as follows:
+
+ ```python
+ from joblib import parallel_config
+
+ with parallel_config(n_jobs=4):
+ valuation.fit(data)
+ ```
+
+ """
+ self.helper_model = self.helper_model.fit(data.x)
+ n_obs = len(data.x)
+ n_test = len(self.utility.test_data)
+
+ generator = zip(self.utility.test_data.x, self.utility.test_data.y)
+
+ generator_with_progress = tqdm(
+ generator,
+ total=n_test,
+ disable=not self.progress,
+ position=0,
+ )
+
+ with Parallel(return_as="generator") as parallel:
+ results = parallel(
+ delayed(self._compute_values_for_one_test_point)(
+ self.helper_model, x, y, data.y
+ )
+ for x, y in generator_with_progress
+ )
+ values = np.zeros(n_obs)
+ for res in results:
+ values += res
+ values /= n_test
+
+ res = ValuationResult(
+ algorithm="knn_shapley",
+ status=Status.Converged,
+ values=values,
+ data_names=data.data_names,
+ )
+
+ self.result = res
+ return self
+
+ @staticmethod
+ def _compute_values_for_one_test_point(
+ helper_model: NearestNeighbors, x: NDArray, y: int, y_train: NDArray
+ ) -> np.ndarray:
+ """Compute the Shapley value for a single test data point.
+
+ The shapley values of the whole test set are the average of the shapley values
+ of the single test data points.
+
+ Args:
+ helper_model: A fitted NearestNeighbors model.
+ x: A single test data point.
+ y: The correct label of the test data point.
+ y_train: The training labels.
+
+ Returns:
+ The Shapley values for the test data point.
+
+ """
+ n_obs = len(y_train)
+ n_neighbors = helper_model.get_params()["n_neighbors"]
+
+ # sorts data indices from close to far
+ sorted_indices = helper_model.kneighbors(
+ x.reshape(1, -1), n_neighbors=n_obs, return_distance=False
+ )[0]
+
+ values = np.zeros(n_obs)
+
+ idx = sorted_indices[-1]
+ values[idx] = float(y_train[idx] == y) / n_obs
+ # reverse range because we want to go from far to close
+ for i in range(n_obs - 1, 0, -1):
+ prev_idx = sorted_indices[i]
+ idx = sorted_indices[i - 1]
+ values[idx] = values[prev_idx]
+ values[idx] += (int(y_train[idx] == y) - int(y_train[prev_idx] == y)) / max(
+ n_neighbors, i
+ )
+
+ return values
diff --git a/src/pydvl/valuation/methods/msr_banzhaf.py b/src/pydvl/valuation/methods/msr_banzhaf.py
index c32ba2e1a..28c867ffe 100644
--- a/src/pydvl/valuation/methods/msr_banzhaf.py
+++ b/src/pydvl/valuation/methods/msr_banzhaf.py
@@ -161,9 +161,6 @@ def _combine_results(
estimates, misleading update counts and even wrong values if no further
precaution is taken.
- TODO: Verify that the two running means are statistically independent (which is
- assumed in the aggregation of variances).
-
Args:
pos_result: The result of the positive updates.
neg_result: The result of the negative updates.
@@ -172,6 +169,9 @@ def _combine_results(
Returns:
The combined valuation result.
+ TODO: Verify that the two running means are statistically independent (which is
+ assumed in the aggregation of variances).
+
"""
# define counts as minimum of the two counts (see docstring)
counts = np.minimum(pos_result.counts, neg_result.counts)
diff --git a/src/pydvl/valuation/samplers/msr.py b/src/pydvl/valuation/samplers/msr.py
index be5a7a32a..697c7f213 100644
--- a/src/pydvl/valuation/samplers/msr.py
+++ b/src/pydvl/valuation/samplers/msr.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, List
import numpy as np
@@ -71,7 +71,7 @@ class MSREvaluationStrategy(EvaluationStrategy):
def process(
self, batch: SampleBatch, is_interrupted: NullaryPredicate
- ) -> list[ValueUpdate]:
+ ) -> List[MSRValueUpdate]:
updates = []
for sample in batch:
updates.extend(self._process_sample(sample))
@@ -79,7 +79,7 @@ def process(
break
return updates
- def _process_sample(self, sample: Sample) -> list[ValueUpdate]:
+ def _process_sample(self, sample: Sample) -> List[MSRValueUpdate]:
u_value = self.utility(sample)
mask = np.zeros(self.n_indices, dtype=bool)
mask[sample.subset] = True
diff --git a/src/pydvl/valuation/scorers/__init__.py b/src/pydvl/valuation/scorers/__init__.py
index 3b8d24d74..c179ba94e 100644
--- a/src/pydvl/valuation/scorers/__init__.py
+++ b/src/pydvl/valuation/scorers/__init__.py
@@ -1,4 +1,5 @@
from .base import *
from .classwise import *
+from .knn import *
from .supervised import *
from .utils import *
diff --git a/src/pydvl/valuation/scorers/knn.py b/src/pydvl/valuation/scorers/knn.py
new file mode 100644
index 000000000..dcfff1471
--- /dev/null
+++ b/src/pydvl/valuation/scorers/knn.py
@@ -0,0 +1,41 @@
+"""Specialized scorer for k-nearest neighbors models."""
+
+from typing import cast
+
+import numpy as np
+from numpy.typing import NDArray
+from sklearn.neighbors import KNeighborsClassifier
+
+from pydvl.valuation.scorers import SupervisedScorer, SupervisedScorerCallable
+
+
+class KNNClassifierScorer(SupervisedScorer):
+ """Scorer for KNN classifier models based on the KNN likelihood.
+
+ Typically, users will not create instances of this class directly but indirectly
+ by using `pydvl.valuation.utility.KNNUtility`.
+
+ Args:
+ test_data: The test data to evaluate the model on.
+
+ """
+
+ def __init__(self, test_data):
+ def scoring(model: KNeighborsClassifier, X: NDArray, y: NDArray) -> float:
+ probs = model.predict_proba(X)
+ label_to_pos = {label: i for i, label in enumerate(model.classes_)}
+ likelihoods = []
+ for i in range(len(y)):
+ if y[i] not in label_to_pos:
+ likelihoods.append(0.0)
+ else:
+ likelihoods.append(probs[i, label_to_pos[y[i]]])
+ return cast(float, np.mean(likelihoods))
+
+ super().__init__(
+ scoring=cast(SupervisedScorerCallable, scoring),
+ test_data=test_data,
+ default=0.0,
+ range=(0, 1),
+ name="KNN Scorer",
+ )
diff --git a/src/pydvl/valuation/utility/__init__.py b/src/pydvl/valuation/utility/__init__.py
index 275f62c83..7ba1fa559 100644
--- a/src/pydvl/valuation/utility/__init__.py
+++ b/src/pydvl/valuation/utility/__init__.py
@@ -18,5 +18,6 @@
compute the score.
"""
+from .modelutility import * # isort: skip
+from .knn import *
from .learning import *
-from .modelutility import *
diff --git a/src/pydvl/valuation/utility/knn.py b/src/pydvl/valuation/utility/knn.py
new file mode 100644
index 000000000..fb375b4d3
--- /dev/null
+++ b/src/pydvl/valuation/utility/knn.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from sklearn.neighbors import KNeighborsClassifier
+
+from pydvl.utils.caching import CacheBackend, CachedFuncConfig
+from pydvl.valuation.dataset import Dataset
+from pydvl.valuation.scorers import KNNClassifierScorer
+from pydvl.valuation.types import Sample
+from pydvl.valuation.utility import ModelUtility
+
+__all__ = ["KNNClassifierUtility"]
+
+
+class KNNClassifierUtility(ModelUtility[Sample, KNeighborsClassifier]):
+ """Utility object for KNN Classifiers.
+
+ The utility function is the likelihood of the true class given the model's
+ prediction.
+
+ This works both as a Utility object for general game theoretic valuation methods and
+ for specialized valuation methods for KNN classifiers.
+
+ Args:
+ model: A KNN classifier model.
+ test_data: The test data to evaluate the model on.
+ catch_errors: set to `True` to catch the errors when `fit()` fails. This
+ could happen in several steps of the pipeline, e.g. when too little
+ training data is passed, which happens often during Shapley value
+ calculations. When this happens, the [scorer's default
+ value][pydvl.valuation.scorers.SupervisedScorer] is returned as a score and
+ computation continues.
+ show_warnings: Set to `False` to suppress warnings thrown by `fit()`.
+ cache_backend: Optional instance of [CacheBackend][
+ pydvl.utils.caching.base.CacheBackend] used to wrap the _utility method of
+ the Utility instance. By default, this is set to None and that means that
+ the utility evaluations will not be cached.
+ cached_func_options: Optional configuration object for cached utility
+ evaluation.
+ clone_before_fit: If `True`, the model will be cloned before calling
+ `fit()`.
+
+ """
+
+ def __init__(
+ self,
+ model: KNeighborsClassifier,
+ test_data: Dataset,
+ *,
+ catch_errors: bool = True,
+ show_warnings: bool = False,
+ cache_backend: CacheBackend | None = None,
+ cached_func_options: CachedFuncConfig | None = None,
+ clone_before_fit: bool = True,
+ ):
+ scorer = KNNClassifierScorer(test_data)
+
+ self.test_data = test_data
+
+ super().__init__(
+ model=model,
+ scorer=scorer,
+ catch_errors=catch_errors,
+ show_warnings=show_warnings,
+ cache_backend=cache_backend,
+ cached_func_options=cached_func_options,
+ clone_before_fit=clone_before_fit,
+ )
diff --git a/src/pydvl/valuation/utility/modelutility.py b/src/pydvl/valuation/utility/modelutility.py
index 0c86b6eca..b35535942 100644
--- a/src/pydvl/valuation/utility/modelutility.py
+++ b/src/pydvl/valuation/utility/modelutility.py
@@ -173,7 +173,6 @@ def _utility(self, sample: SampleT) -> float:
model or the scorer returns [numpy.NaN][]. Otherwise, the score
of the model.
"""
-
if self.training_data is None:
raise ValueError("No training data provided")
diff --git a/tests/valuation/methods/test_knn_shapley.py b/tests/valuation/methods/test_knn_shapley.py
new file mode 100644
index 000000000..c5c8ad4df
--- /dev/null
+++ b/tests/valuation/methods/test_knn_shapley.py
@@ -0,0 +1,74 @@
+import numpy as np
+import pytest
+from joblib import parallel_config
+from sklearn import datasets
+from sklearn.neighbors import KNeighborsClassifier
+
+from pydvl.utils.dataset import Dataset as OldDataset
+from pydvl.utils.utility import Utility as OldUtility
+from pydvl.valuation.dataset import Dataset
+from pydvl.valuation.methods import DataShapleyValuation, KNNShapleyValuation
+from pydvl.valuation.samplers import PermutationSampler
+from pydvl.valuation.stopping import MinUpdates
+from pydvl.valuation.utility import KNNClassifierUtility
+from pydvl.value.shapley.knn import knn_shapley as old_knn_shapley
+
+
+@pytest.fixture(scope="module")
+def data():
+ return Dataset.from_sklearn(
+ datasets.load_iris(),
+ train_size=0.05,
+ random_state=1234,
+ stratify_by_target=True,
+ )
+
+
+@pytest.fixture(scope="module")
+def montecarlo_results(data):
+ model = KNeighborsClassifier(n_neighbors=5)
+ data_train, data_test = data
+ utility = KNNClassifierUtility(model=model, test_data=data_test)
+ sampler = PermutationSampler(seed=42)
+ montecarlo_valuation = DataShapleyValuation(
+ utility,
+ sampler=sampler,
+ is_done=MinUpdates(1000),
+ progress=False,
+ )
+ return montecarlo_valuation.fit(data_train).values()
+
+
+@pytest.mark.parametrize("n_jobs", [1, 2])
+def test_against_montecarlo(n_jobs, data, montecarlo_results):
+ model = KNeighborsClassifier(n_neighbors=5)
+ data_train, data_test = data
+ utility = KNNClassifierUtility(model=model, test_data=data_test)
+ valuation = KNNShapleyValuation(utility, progress=False)
+
+ with parallel_config(n_jobs=n_jobs):
+ results = valuation.fit(data_train).values()
+
+ np.testing.assert_allclose(
+ results.values, montecarlo_results.values, atol=1e-2, rtol=1e-2
+ )
+
+
+@pytest.mark.xfail(reason="Suspected bug in old implementation.")
+def test_old_vs_new(seed, data):
+ model = KNeighborsClassifier(n_neighbors=5)
+ old_data = OldDataset.from_sklearn(
+ datasets.load_iris(),
+ train_size=0.05,
+ random_state=seed,
+ stratify_by_target=True,
+ )
+ old_u = OldUtility(model=model, data=old_data)
+ old_values = old_knn_shapley(old_u, progress=False).values
+
+ data_train, data_test = data
+ utility = KNNClassifierUtility(model=model, test_data=data_test)
+ new_valuation = KNNShapleyValuation(utility, progress=False)
+ new_values = new_valuation.fit(data_train).values().values
+
+ np.testing.assert_allclose(new_values, old_values, atol=1e-2, rtol=1e-2)