Skip to content

Commit

Permalink
🐛 Fix initialization of fitted attribute (#10)
Browse files Browse the repository at this point in the history
* 🐛 Fix initialization of fitted attribute
  • Loading branch information
davnn authored May 1, 2024
1 parent ee77dfd commit 887379a
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 12 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ on:

jobs:
check:
permissions:
# Give the default GITHUB_TOKEN write permission to commit and push the
# added or changed files to the repository.
contents: write

strategy:
fail-fast: false
matrix:
Expand All @@ -18,3 +23,4 @@ jobs:
os: ${{ matrix.os }}
python: ${{ matrix.python }}
command: task check
push: true
10 changes: 10 additions & 0 deletions .github/workflows/setup.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ on:
command:
required: true
type: string
push:
required: false
type: boolean

env:
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }}
Expand Down Expand Up @@ -55,3 +58,10 @@ jobs:
- name: Command
run: ${{ inputs.command }}
shell: bash

- name: Push
if: ${{ inputs.push && inputs.os == 'ubuntu-latest' }}
# Commit all changed files back to the repository
uses: stefanzweifel/git-auto-commit-action@v5
with:
commit_message: ":white_check_mark: automated change [skip ci]"
6 changes: 3 additions & 3 deletions assets/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "nearness"
version = "0.1.0"
version = "0.1.1"
description = "An easy-to-use interface for (approximate) nearest neighbors algorithms."
readme = "README.md"
authors = ["David Muhr <[email protected]>"]
Expand Down
2 changes: 1 addition & 1 deletion src/nearness/_annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
index = AnnoyIndex(self.parameters.load_index_dim, self.parameters.metric)
index.load(str(path))
self._index = index
self.is_fitted = True
self.__fitted__ = True

@overload
def fit(self, data: Iterable[Real[NumpyArray, "d"]]) -> "AnnoyNeighbors": ...
Expand Down
19 changes: 16 additions & 3 deletions src/nearness/_base/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,18 @@ def __call__(cls, *_: Any, **kwargs: Any) -> "NearestNeighbors":
"methods_require_fit",
partial(cls._check_callback, obj=obj),
)
obj.__fitted__ = False
if not hasattr(obj, "__fitted__"):
msg = (
f"Instantiated {obj}, but missing the '__fitted__' attribute, which is automatically set to False in "
f"'NearestNeighbors.__init__', did you forget to call 'super().__init__()' in the '__init__' of "
f"{obj}? Assuming that '__fitted__' is 'False'."
)
warn(msg, stacklevel=1)
obj.__fitted__ = False
if not obj.__fitted__:
# __fitted__ might be true if the index is pre-loaded in the ``__init__``.
cls._wrap_check_method(obj)
cls._wrap_fit_method(obj)
cls._wrap_check_method(obj)
del cls._parameters_
del cls._config_
return obj
Expand Down Expand Up @@ -198,6 +207,10 @@ class NearestNeighbors(metaclass=NearestNeighborsMeta):
and returning floating-point distances of equal type as output.
"""

def __init__(self) -> None:
super().__init__()
self.__fitted__ = False

@abstractmethod
def fit(self, data: np.ndarray) -> Self:
"""Learn an index structure based on a matrix of points.
Expand Down Expand Up @@ -345,7 +358,7 @@ def is_fitted(self, value: bool) -> None:
NearestNeighbors._wrap_check_method(self)

# this variable is initialized in the metaclass
self.__fitted__ = value # type: ignore[reportUninitializedInstanceVariable]
self.__fitted__ = value

@property
def config(self) -> "Config":
Expand Down
12 changes: 8 additions & 4 deletions tests/test_nearness.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ def create_instance():

# Get the current thread id
thread_id = threading.current_thread().ident
original_thread_ids.add(thread_id)

# Set the thread id on the model
model = Model(thread_id=thread_id)
Expand Down Expand Up @@ -398,14 +397,16 @@ def test_keyword_only():
with pytest.raises(InvalidSignatureError):

class N(NearestNeighbors):
def __init__(self, a): ...
def __init__(self, a):
super().__init__()


def test_warn_check():
class ModelNoConfig(NearestNeighbors):
no_method = True

def __init__(self): ...
def __init__(self):
super().__init__()

def fit(self, data): ...

Expand All @@ -415,6 +416,7 @@ class ModelWrongAttribute(NearestNeighbors):
no_method = True

def __init__(self):
super().__init__()
self.config.methods_require_fit = self.config.methods_require_fit | {"no_method"}

def fit(self, data): ...
Expand All @@ -423,6 +425,7 @@ def query(self, point, n_neighbors): ...

class ModelMissingAttribute(NearestNeighbors):
def __init__(self):
super().__init__()
self.config.methods_require_fit = self.config.methods_require_fit | {"missing_method"}

def fit(self, data): ...
Expand All @@ -446,7 +449,8 @@ def query(self, point, n_neighbors): ...

def test_check_attribute():
class Model(NearestNeighbors):
def __init__(self): ...
def __init__(self):
super().__init__()

def fit(self, data):
return self
Expand Down

0 comments on commit 887379a

Please sign in to comment.