Skip to content

Commit

Permalink
⬆️ Bump safecheck and dev dependencies (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
davnn authored Sep 18, 2024
1 parent d815a9c commit a64b7e2
Show file tree
Hide file tree
Showing 6 changed files with 1,206 additions and 1,176 deletions.
2,334 changes: 1,182 additions & 1,152 deletions poetry.lock

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ classifiers = [

[tool.poetry.dependencies]
python = ">=3.10,<3.13"
safecheck = "^0.3"
safecheck = ">=0.3,<0.5"
joblib = "^1"
typing-extensions = ">=4"
numpy = { version = ">=1", optional = true }
Expand Down Expand Up @@ -66,18 +66,18 @@ scann = ["scann"]
jax = ["jax"]

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.0"
pytest = "^8.3.3"
pytest-html = "^4.1.1"
hypothesis = "^6.100.2"
coverage = "^7.5.0"
hypothesis = "^6.112.1"
coverage = "^7.6.1"
pytest-cov = "^5.0.0"
coverage-badge = "^1.1.1"
ruff = "^0.4.2"
pre-commit = "^3.7.0"
black = "^24.4.2"
pyright = "^1.1.360"
bandit = "^1.7.8"
safety = "^3.1.0"
coverage-badge = "^1.1.2"
ruff = "^0.6.5"
pre-commit = "^3.8.0"
black = "^24.8.0"
pyright = "^1.1.381"
bandit = "^1.7.9"
safety = "^3.2.7"
notebook = "^7.0.4"
pytest-benchmark = "^4.0.0"

Expand Down
2 changes: 1 addition & 1 deletion src/nearness/_autofaiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,5 @@ def query_batch(
points: Float[NumpyArray, "m d"],
n_neighbors: int,
) -> tuple[Int64[NumpyArray, "m {n_neighbors}"], Float32[NumpyArray, "m {n_neighbors}"]]:
dist, idx = self._index.search(points, n_neighbors) # type: ignore[reportGeneralTypeIssues]
dist, idx = self._index.search(points, n_neighbors) # type: ignore[reportCallIssue]
return idx, dist
12 changes: 6 additions & 6 deletions src/nearness/_base/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> t
"""Check the signature of ``__init__`` to ensure keyword-only arguments."""
if "__init__" in attrs:
parameters = inspect.signature(attrs["__init__"]).parameters
for name, parameter in parameters.items():
if name != "self" and (kind := parameter.kind) is not inspect.Parameter.KEYWORD_ONLY:
for param_name, parameter in parameters.items():
if param_name != "self" and (kind := parameter.kind) is not inspect.Parameter.KEYWORD_ONLY:
msg = (
"Only keyword-only arguments are allowed for classes inheriting from 'nearness."
f"NearestNeighbors', but found parameter '{parameter}' of kind '{kind}'."
Expand Down Expand Up @@ -89,9 +89,9 @@ def __call__(cls, *_: Any, **kwargs: Any) -> "NearestNeighbors":
# now we set all the relevant attributes on the ``instance``, as they should not be class-bound.
# order is important here as ``_wrap_fit_method`` and ``_wrap_check_method`` depend on the set attributes
obj = type.__call__(cls, **kwargs)
obj._parameters_, obj._config_ = cls._parameters_, cls._config_ # noqa: SLF001
obj._parameters_, obj._config_ = cls._parameters_, cls._config_
# make sure that the wrapped methods are in sync when the config is changed after class instantiation
obj._config_.register_callback( # noqa: SLF001
obj._config_.register_callback(
"methods_require_fit",
partial(cls._check_callback, obj=obj),
)
Expand Down Expand Up @@ -144,7 +144,7 @@ def _wrap_check_method(cls, obj: "NearestNeighbors") -> None:
"""
logger.debug("Starting to wrap methods to enable fit checking.")
available_attributes = dir(obj)
methods_to_wrap = obj._config_.methods_require_fit # noqa: SLF001
methods_to_wrap = obj._config_.methods_require_fit
for attribute_name in available_attributes:
attribute = getattr(obj, attribute_name)
has_check = hasattr(attribute, "__check__")
Expand Down Expand Up @@ -186,7 +186,7 @@ def _unwrap_check_method(cls, obj: "NearestNeighbors") -> None:
is set in ``fit``, but unsafe when ``__fitted__`` is manually set to ``False`` after ``fit``.
"""
logger.debug("Starting to unwrap all fit checking methods.")
for method_name in obj._config_.methods_require_fit: # noqa: SLF001
for method_name in obj._config_.methods_require_fit:
# we set an __requires_fit__ attribute on the wrapper, because using ``__wrapped__`` alone is not
# safe (methods also use ``__wrapped__`` starting with Python 3.10)
if hasattr(obj, method_name) and hasattr(method := getattr(obj, method_name), "__check__"):
Expand Down
8 changes: 4 additions & 4 deletions src/nearness/_faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ def __init__(
def fit(self, data: Float[NumpyArray, "n d"]) -> "FaissNeighbors":
_, dim = data.shape
self._index = self._create_index(dim)
self._index.train(data) # type: ignore[reportGeneralTypeIssues]
self._index.train(data) # type: ignore[reportCallIssue]

if self.parameters.add_data_on_fit:
# data might be added directly on fit, or using the ``add`` method
self._index.add(data) # type: ignore[reportGeneralTypeIssues]
self._index.add(data) # type: ignore[reportCallIssue]

return self

@typecheck
def add(self, data: Float[NumpyArray, "n d"]) -> "FaissNeighbors":
self._index.add(data) # type: ignore[reportGeneralTypeIssues]
self._index.add(data) # type: ignore[reportCallIssue]
return self

def query(
Expand All @@ -76,7 +76,7 @@ def query_batch(
points: Float[NumpyArray, "m d"],
n_neighbors: int,
) -> tuple[Int64[NumpyArray, "m {n_neighbors}"], Float32[NumpyArray, "m {n_neighbors}"]]:
dist, idx = self._index.search(points, n_neighbors) # type: ignore[reportGeneralTypeIssues]
dist, idx = self._index.search(points, n_neighbors) # type: ignore[reportCallIssue]
return idx, dist

def _create_index(self, dim: int) -> faiss.Index:
Expand Down
4 changes: 2 additions & 2 deletions src/nearness/_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def query_batch(
points: Float[NumpyArray, "m d"],
n_neighbors: int,
) -> tuple[Int64[NumpyArray, "m {n_neighbors}"], Float[NumpyArray, "m {n_neighbors}"]]:
distance = cdist(
distance = cdist( # type: ignore[reportCallIssue]
points,
self._data, # type: ignore[reportGeneralTypeIssues]
self._data, # type: ignore[reportArgumentType]
metric=self.parameters.metric,
**self.parameters.metric_args,
)
Expand Down

0 comments on commit a64b7e2

Please sign in to comment.