Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make embedders deseralize to correct type #927

Merged
merged 3 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions meilisearch/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@
from meilisearch.config import Config
from meilisearch.errors import version_error_hint_message
from meilisearch.models.document import Document, DocumentsResults
from meilisearch.models.index import Embedders, Faceting, IndexStats, Pagination, TypoTolerance
from meilisearch.models.index import (
Embedders,
Faceting,
HuggingFaceEmbedder,
IndexStats,
OpenAiEmbedder,
Pagination,
TypoTolerance,
UserProvidedEmbedder,
)
from meilisearch.models.task import Task, TaskInfo, TaskResults
from meilisearch.task import TaskHandler

Expand Down Expand Up @@ -865,7 +874,23 @@ def get_settings(self) -> Dict[str, Any]:
MeilisearchApiError
An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors
"""
return self.http.get(f"{self.config.paths.index}/{self.uid}/{self.config.paths.setting}")
settings = self.http.get(
f"{self.config.paths.index}/{self.uid}/{self.config.paths.setting}"
)

if settings.get("embedders"):
embedders: dict[str, OpenAiEmbedder | HuggingFaceEmbedder | UserProvidedEmbedder] = {}
for k, v in settings["embedders"].items():
if v.get("source") == "openAi":
embedders[k] = OpenAiEmbedder(**v)
elif v.get("source") == "huggingFace":
embedders[k] = HuggingFaceEmbedder(**v)
else:
embedders[k] = UserProvidedEmbedder(**v)

settings["embedders"] = embedders

return settings

def update_settings(self, body: Mapping[str, Any]) -> TaskInfo:
"""Update settings of the index.
Expand Down Expand Up @@ -1777,7 +1802,17 @@ def get_embedders(self) -> Embedders | None:
if not response:
return None

return Embedders(embedders=response)
embedders: dict[str, OpenAiEmbedder | HuggingFaceEmbedder | UserProvidedEmbedder] = {}
for k, v in response.items():
print(v.get("source"))
if v.get("source") == "openAi":
embedders[k] = OpenAiEmbedder(**v)
elif v.get("source") == "huggingFace":
embedders[k] = HuggingFaceEmbedder(**v)
else:
embedders[k] = UserProvidedEmbedder(**v)

return Embedders(embedders=embedders)

def update_embedders(self, body: Union[Mapping[str, Any], None]) -> TaskInfo:
"""Update embedders of the index.
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import meilisearch
from meilisearch.errors import MeilisearchApiError
from meilisearch.models.index import HuggingFaceEmbedder, OpenAiEmbedder, UserProvidedEmbedder
from tests import common


Expand Down Expand Up @@ -230,8 +231,7 @@ def enable_vector_search():
@fixture
def new_embedders():
return {
"default": {
"source": "userProvided",
"dimensions": 1,
}
"default": UserProvidedEmbedder(dimensions=1).model_dump(by_alias=True),
"open_ai": OpenAiEmbedder().model_dump(by_alias=True),
"hugging_face": HuggingFaceEmbedder().model_dump(by_alias=True),
}
40 changes: 28 additions & 12 deletions tests/settings/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
NEW_SETTINGS = {
"rankingRules": ["typo", "words"],
"searchableAttributes": ["title", "overview"],
}
# pylint: disable=redefined-outer-name
import pytest

from meilisearch.models.index import HuggingFaceEmbedder, OpenAiEmbedder, UserProvidedEmbedder


@pytest.fixture
def new_settings(new_embedders):
return {
"rankingRules": ["typo", "words"],
"searchableAttributes": ["title", "overview"],
"embedders": new_embedders,
}


DEFAULT_RANKING_RULES = ["words", "typo", "proximity", "attribute", "sort", "exactness"]

Expand Down Expand Up @@ -31,36 +41,41 @@ def test_get_settings_default(empty_index):
assert response["synonyms"] == {}


def test_update_settings(empty_index):
@pytest.mark.usefixtures("enable_vector_search")
def test_update_settings(new_settings, empty_index):
"""Tests updating some settings."""
index = empty_index()
response = index.update_settings(NEW_SETTINGS)
response = index.update_settings(new_settings)
update = index.wait_for_task(response.task_uid)
assert update.status == "succeeded"
response = index.get_settings()
for rule in NEW_SETTINGS["rankingRules"]:
for rule in new_settings["rankingRules"]:
assert rule in response["rankingRules"]
assert response["distinctAttribute"] is None
for attribute in NEW_SETTINGS["searchableAttributes"]:
for attribute in new_settings["searchableAttributes"]:
assert attribute in response["searchableAttributes"]
assert response["displayedAttributes"] == ["*"]
assert response["stopWords"] == []
assert response["synonyms"] == {}
assert isinstance(response["embedders"]["default"], UserProvidedEmbedder)
assert isinstance(response["embedders"]["open_ai"], OpenAiEmbedder)
assert isinstance(response["embedders"]["hugging_face"], HuggingFaceEmbedder)


def test_reset_settings(empty_index):
@pytest.mark.usefixtures("enable_vector_search")
def test_reset_settings(new_settings, empty_index):
"""Tests resetting all the settings to their default value."""
index = empty_index()
# Update settings first
response = index.update_settings(NEW_SETTINGS)
response = index.update_settings(new_settings)
update = index.wait_for_task(response.task_uid)
assert update.status == "succeeded"
# Check the settings have been correctly updated
response = index.get_settings()
for rule in NEW_SETTINGS["rankingRules"]:
for rule in new_settings["rankingRules"]:
assert rule in response["rankingRules"]
assert response["distinctAttribute"] is None
for attribute in NEW_SETTINGS["searchableAttributes"]:
for attribute in new_settings["searchableAttributes"]:
assert attribute in response["searchableAttributes"]
assert response["displayedAttributes"] == ["*"]
assert response["stopWords"] == []
Expand All @@ -80,3 +95,4 @@ def test_reset_settings(empty_index):
assert response["searchableAttributes"] == ["*"]
assert response["stopWords"] == []
assert response["synonyms"] == {}
assert response.get("embedders") is None
19 changes: 13 additions & 6 deletions tests/settings/test_settings_embedders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=redefined-outer-name
import pytest

from meilisearch.models.index import Embedders
from meilisearch.models.index import HuggingFaceEmbedder, OpenAiEmbedder, UserProvidedEmbedder


@pytest.mark.usefixtures("enable_vector_search")
Expand All @@ -19,7 +20,9 @@ def test_update_embedders_with_user_provided_source(new_embedders, empty_index):
update = index.wait_for_task(response_update.task_uid)
response_get = index.get_embedders()
assert update.status == "succeeded"
assert response_get == Embedders(embedders=new_embedders)
assert isinstance(response_get.embedders["default"], UserProvidedEmbedder)
assert isinstance(response_get.embedders["open_ai"], OpenAiEmbedder)
assert isinstance(response_get.embedders["hugging_face"], HuggingFaceEmbedder)


@pytest.mark.usefixtures("enable_vector_search")
Expand All @@ -30,15 +33,19 @@ def test_reset_embedders(new_embedders, empty_index):
# Update the settings
response_update = index.update_embedders(new_embedders)
update1 = index.wait_for_task(response_update.task_uid)
assert update1.status == "succeeded"
# Get the setting after update
response_get = index.get_embedders()
assert isinstance(response_get.embedders["default"], UserProvidedEmbedder)
assert isinstance(response_get.embedders["open_ai"], OpenAiEmbedder)
assert isinstance(response_get.embedders["hugging_face"], HuggingFaceEmbedder)
# Reset the setting
response_reset = index.reset_embedders()
update2 = index.wait_for_task(response_reset.task_uid)
# Get the setting after reset
response_last = index.get_embedders()

assert update1.status == "succeeded"
assert response_get == Embedders(embedders=new_embedders)
assert update2.status == "succeeded"
assert isinstance(response_get.embedders["default"], UserProvidedEmbedder)
assert isinstance(response_get.embedders["open_ai"], OpenAiEmbedder)
assert isinstance(response_get.embedders["hugging_face"], HuggingFaceEmbedder)
response_last = index.get_embedders()
assert response_last is None