Skip to content

Commit

Permalink
✅ Improve the existing tests (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
davnn authored May 1, 2024
1 parent 9f64426 commit ee77dfd
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 38 deletions.
6 changes: 3 additions & 3 deletions assets/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 10 additions & 10 deletions src/nearness/_base/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,31 +87,32 @@ def __call__(cls, *_: Any, **kwargs: Any) -> "NearestNeighbors":
logger.info("Determined parameters as '%s'", cls._parameters_)

# now we set all the relevant attributes on the ``instance``, as they should not be class-bound.
# order is important here as ``__wrap_fit_method`` and ``__wrap_check_method`` depend on the set attributes
# order is important here as ``_wrap_fit_method`` and ``_wrap_check_method`` depend on the set attributes
obj = type.__call__(cls, **kwargs)
obj._parameters_, obj._config_ = cls._parameters_, cls._config_ # noqa: SLF001
# make sure that the wrapped methods are in sync when the config is changed after class instantiation
obj._config_.register_callback( # noqa: SLF001
"methods_require_fit",
partial(cls._check_callback, obj=obj),
)
obj.__fitted__ = False
cls._wrap_fit_method(obj)
cls._wrap_check_method(obj)
del cls._parameters_
del cls._config_
return obj

def _check_callback(cls, _: Any, *, obj: "NearestNeighbors") -> None:
"""A callback to refresh the ``__fitted__`` attributes when the 'methods_require_fit' attribute is set ."""
"""A callback to refresh the ``__fitted__`` attributes when the 'methods_require_fit' attribute is updated."""
if not obj.is_fitted: # if the object is already fitted there is no need to wrap any methods
cls._wrap_check_method(obj)

def _wrap_fit_method(cls, obj: "NearestNeighbors") -> None:
"""Wrap the ``fit`` method to ensure it sets the ``__fitted__`` attribute and unwraps the query methods.
"""Wrap the ``fit`` method to ensure it sets ``is_fitted`` to ``True``, which unwraps the query method checks.
This feels a bit like magic, but the alternatives would be to:
1. Add a decorator to every ``fit`` method that sets ``__fitted__``.
2. Set the ``__fitted__`` attribute on every ``fit``-like method and check for ``__fitted__`` everywhere.
1. Add a decorator to every ``fit`` method that sets ``is_fitted`` to ``True``.
2. Set ``is_fitted`` on every ``fit``-like method and check for ``is_fitted`` everywhere.
"""
fit = obj.fit

Expand Down Expand Up @@ -173,7 +174,7 @@ def _unwrap_check_method(cls, obj: "NearestNeighbors") -> None:
This is just a performance optimization, it retrieves the original method and removes the implicitly
generated function wrapper (decorator). It should be safe to unwrap the methods if ``__fitted__``
is only set in ``fit``, but unsafe when ``__fitted__`` is manually set to ``False`` after ``fit``.
is set in ``fit``, but unsafe when ``__fitted__`` is manually set to ``False`` after ``fit``.
"""
logger.debug("Starting to unwrap all fit checking methods.")
for method_name in obj._config_.methods_require_fit: # noqa: SLF001
Expand All @@ -197,8 +198,6 @@ class NearestNeighbors(metaclass=NearestNeighborsMeta):
and returning floating-point distances of equal type as output.
"""

__fitted__ = False

@abstractmethod
def fit(self, data: np.ndarray) -> Self:
"""Learn an index structure based on a matrix of points.
Expand Down Expand Up @@ -345,7 +344,8 @@ def is_fitted(self, value: bool) -> None:
if not value:
NearestNeighbors._wrap_check_method(self)

self.__fitted__ = value
# this variable is initialized in the metaclass
self.__fitted__ = value # type: ignore[reportUninitializedInstanceVariable]

@property
def config(self) -> "Config":
Expand All @@ -367,7 +367,7 @@ def _create_parameter_class(
empty = inspect.Parameter.empty
parameter_types = [(k, Any if (a := v.annotation) is empty else a) for k, v in parameters.items()]
parameter_values = {k: kwargs.get(k, v.default) for k, v in parameters.items()}
return make_dataclass("Parameters", parameter_types)(**parameter_values) # type: ignore[reportArgumentType]
return make_dataclass("Parameters", parameter_types)(**parameter_values)


def _create_check_wrapper(obj: NearestNeighbors, method: Callable[..., Any]) -> Callable[..., Any]:
Expand Down
7 changes: 1 addition & 6 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,12 @@
nq=1, # number of single queries
)

# we only benchmark brute-force algorithms for now
candidates = {
"annoy": lambda: AnnoyNeighbors(metric="euclidean", n_search_neighbors=256),
"faiss-ivf-pq": lambda: FaissNeighbors(index_or_factory="OPQ8,IVF128,PQ8", sample_train_points=10_000),
"faiss-hnsw-pq": lambda: FaissNeighbors(index_or_factory="OPQ8,HNSW_PQ8"),
"faiss-nsg-pq": lambda: FaissNeighbors(index_or_factory="OPQ8,NSG_PQ8"),
"faiss-brute": lambda: FaissNeighbors(index_or_factory="Flat"),
"hnsw": lambda: HNSWNeighbors(metric="l2", n_threads=-1, n_search_neighbors=128, n_index_neighbors=256),
"hnsw-brute": lambda: HNSWNeighbors(metric="l2", n_threads=-1, use_bruteforce=True),
"jax": lambda: JaxNeighbors(compute_mode="use_mm_for_euclid_dist", approximate_recall_target=0.9),
"numpy": lambda: NumpyNeighbors(metric="minkowski", p=2, compute_mode="use_mm_for_euclid_dist"),
"scann": lambda: ScannNeighbors(search_parallel=True, use_tree=True, use_bruteforce=False, use_reorder=False),
"scann-brute": lambda: ScannNeighbors(search_parallel=True, use_bruteforce=True, use_tree=False, use_reorder=False),
"scipy": lambda: ScipyNeighbors(metric="euclidean"),
"sklearn": lambda: SklearnNeighbors(metric="euclidean", n_jobs=-1),
Expand Down
47 changes: 32 additions & 15 deletions tests/test_nearness.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Data:
batch: np.ndarray


candidates = {
candidates_original = {
# comparing two equal models might seem unnecessary, but some tests are only performed on the ``implementation``
"sklearn": lambda: Candidate(
implementation=SklearnNeighbors(metric="minkowski", p=2), reference=SklearnNeighbors(metric="minkowski", p=2)
Expand Down Expand Up @@ -83,10 +83,15 @@ class Data:
),
}

candidates = [pytest_param_if_value_available(k, v) for k, v in candidates.items()]
candidates = [pytest_param_if_value_available(k, v) for k, v in candidates_original.items()]
DataGenerationMethod = Literal["random", "mds", "digits", "synthetic"]


def skip_if_missing(key: str) -> Any:
    """Create a ``skipif`` marker that skips a test when the candidate ``key`` is unavailable.

    The candidate factory is looked up in ``candidates_original`` and probed with
    ``value_is_missing`` at decoration time.
    """
    unavailable = value_is_missing(candidates_original[key])
    return pytest.mark.skipif(unavailable, reason=f"Skipping test, {key} is not available.")


def make_data(
dim: int,
fit_size: int = 1,
Expand Down Expand Up @@ -188,9 +193,6 @@ def test_fit_return_self(data, candidate):
@hypothesis.given(data_and_neighbors=neighbors_strategy())
@pytest.mark.parametrize("candidate", candidates)
def test_query(data_and_neighbors, candidate):
if candidate.reference is None:
pytest.skip()

data, n_neighbors = data_and_neighbors

candidate.reference.fit(data.fit)
Expand Down Expand Up @@ -220,9 +222,6 @@ def test_query_idx_dist(data_and_neighbors, candidate):
@hypothesis.given(data_and_neighbors=neighbors_strategy())
@pytest.mark.parametrize("candidate", candidates)
def test_query_batch(data_and_neighbors, candidate):
if candidate.reference is None:
pytest.skip()

data, n_neighbors = data_and_neighbors

candidate.reference.fit(data.fit)
Expand Down Expand Up @@ -271,6 +270,7 @@ def test_save_load_identity(data_and_neighbors, candidate, tmp_path, request):
assert approx_equal(dist_save, dist_load)


@skip_if_missing("annoy")
@hypothesis.given(data_and_neighbors=neighbors_strategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture])
def test_annoy_save_load_identity(data_and_neighbors, tmp_path):
Expand Down Expand Up @@ -320,11 +320,30 @@ def test_faiss_index_creation(data_and_neighbors):
assert approx_equal(dist_str, dist_fun, dist_raw)


@skip_if_missing("torch")
@hypothesis.given(data_and_neighbors=neighbors_strategy())
def test_torch_tensor_input(data_and_neighbors):
    """Check that the torch implementation accepts tensors and agrees with its reference.

    The implementation is fitted and queried with ``torch`` tensors, while the reference
    model receives the same data as NumPy arrays. Indices and distances are compared for
    both the single-query and the batch-query path.
    """
    import torch

    data, n_neighbors = data_and_neighbors
    fit, query, batch = map(torch.from_numpy, [data.fit, data.query, data.batch])
    candidate = candidates_original["torch"]()
    candidate.implementation.fit(fit)
    candidate.reference.fit(data.fit)

    query_idx_numpy, query_dist_numpy = candidate.reference.query(data.query, n_neighbors=n_neighbors)
    query_idx_torch, query_dist_torch = candidate.implementation.query(query, n_neighbors=n_neighbors)
    batch_idx_numpy, batch_dist_numpy = candidate.reference.query_batch(data.batch, n_neighbors=n_neighbors)
    batch_idx_torch, batch_dist_torch = candidate.implementation.query_batch(batch, n_neighbors=n_neighbors)

    # every computed result is verified, not only the query indices and batch distances
    assert array_equal(query_idx_numpy, query_idx_torch.numpy())
    assert approx_equal(query_dist_numpy, query_dist_torch.numpy())
    assert array_equal(batch_idx_numpy, batch_idx_torch.numpy())
    assert approx_equal(batch_dist_numpy, batch_dist_torch.numpy())


def test_thread_safety():
# Number of threads to create
num_threads = 100
original_thread_ids = set()
captured_thread_ids = set()
invalid_thread_ids = set()

class Model(NearestNeighbors):
def __init__(self, *, thread_id: int):
Expand All @@ -348,10 +367,9 @@ def create_instance():
# Set the thread id on the model
model = Model(thread_id=thread_id)
captured_id = model.parameters.thread_id
captured_thread_ids.add(captured_id)

# Perform any assertions on the instance
assert thread_id == model.parameters.thread_id
if thread_id != captured_id:
invalid_thread_ids.add((thread_id, captured_id))

# Create and start threads
threads = [threading.Thread(target=create_instance) for _ in range(num_threads)]
Expand All @@ -366,8 +384,7 @@ def create_instance():
thread.join()

# Perform any additional assertions or checks
assert len(captured_thread_ids) == len(original_thread_ids)
assert captured_thread_ids == original_thread_ids
assert len(invalid_thread_ids) == 0


def test_missing_abstract_method():
Expand Down
15 changes: 11 additions & 4 deletions tests/utilities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from copy import copy

import pytest
import numpy as np
from typing_extensions import Any
Expand All @@ -23,6 +21,15 @@ def array_equal(x: np.ndarray, y: np.ndarray, *arrays: np.ndarray, equal_nan: bo

def pytest_param_if_value_available(key: str, lazy_value: Any) -> Any:  # type: ignore[ANN401]
    """Build a ``pytest.param`` for ``key``, or a skipped placeholder if it cannot be created.

    Args:
        key: Identifier used as the parameter ``id`` and in the skip reason.
        lazy_value: Zero-argument callable that constructs the parameter value.

    Returns:
        A ``pytest.param`` wrapping the constructed value, or a ``pytest.param`` holding
        ``None`` that is marked as skipped when ``lazy_value`` raises ``NameError``
        (presumably because an optional dependency was not imported -- confirm against
        the test module's imports).
    """
    try:
        value = lazy_value()
    except NameError:
        reason = f"Skipping test, {key} is not available."
        return pytest.param(None, id=key, marks=pytest.mark.skip(reason=reason))

    return pytest.param(value, id=key)


def value_is_missing(lazy_value: Any) -> bool:
    """Report whether constructing a lazily defined value fails with ``NameError``.

    Args:
        lazy_value: Zero-argument callable that constructs the value; its result is discarded.

    Returns:
        ``True`` if calling ``lazy_value`` raises ``NameError``, ``False`` otherwise.
        Any other exception propagates to the caller.
    """
    try:
        lazy_value()
    except NameError:
        return True
    return False

0 comments on commit ee77dfd

Please sign in to comment.