From 9fd61e058b7517f95fc2ca22ff2f2b5c73afa794 Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Thu, 8 Feb 2024 22:35:53 -0500 Subject: [PATCH] Allow custom text key to WeaviateRM - resolve issue #359 --- dspy/retrieve/weaviate_rm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/dspy/retrieve/weaviate_rm.py b/dspy/retrieve/weaviate_rm.py index 0a82e2f5a..7a8f3f7d7 100644 --- a/dspy/retrieve/weaviate_rm.py +++ b/dspy/retrieve/weaviate_rm.py @@ -31,7 +31,9 @@ class WeaviateRM(dspy.Retrieve): llm = dspy.OpenAI(model="gpt-3.5-turbo") weaviate_client = weaviate.Client("your-path-here") - retriever_model = WeaviateRM("my_collection_name", weaviate_client=weaviate_client) + retriever_model = WeaviateRM(weaviate_collection_name="my_collection_name", + weaviate_collection_text_key="content", + weaviate_client=weaviate_client) dspy.settings.configure(lm=llm, rm=retriever_model) ``` @@ -44,11 +46,12 @@ class WeaviateRM(dspy.Retrieve): def __init__(self, weaviate_collection_name: str, weaviate_client: weaviate.Client, - k: int = 3 + k: int = 3, + weaviate_collection_text_key: Optional[str] = "content" ): self._weaviate_collection_name = weaviate_collection_name self._weaviate_client = weaviate_client - + self._weaviate_collection_text_key = weaviate_collection_text_key super().__init__(k=k) def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> dspy.Prediction: @@ -71,13 +74,13 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> passages = [] for query in queries: results = self._weaviate_client.query\ - .get(self._weaviate_collection_name, ["content"])\ + .get(self._weaviate_collection_name, [self._weaviate_collection_text_key])\ .with_hybrid(query=query)\ .with_limit(k)\ .do() results = results["data"]["Get"][self._weaviate_collection_name] - parsed_results = [result["content"] for result in results] + parsed_results = [result[self._weaviate_collection_text_key] for result in results] passages.extend(dotdict({"long_text": d}) for d in parsed_results) return passages