From 027fb92d0c4bcade5ca50e19fcbdc9c441544b01 Mon Sep 17 00:00:00 2001 From: Matthew Billman Date: Sun, 25 Feb 2024 17:55:02 -0500 Subject: [PATCH] add k kwarg to signature for YouRM, and update the parent dspy.Retrieve signature to match --- dspy/retrieve/retrieve.py | 20 ++++++++++---------- dspy/retrieve/you_rm.py | 10 +++++++--- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 4f61ba8ae..910858241 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -1,3 +1,5 @@ +from typing import List, Union, Optional + import dsp import random @@ -13,10 +15,10 @@ class Retrieve(Parameter): def __init__(self, k=3): self.stage = random.randbytes(8).hex() self.k = k - + def reset(self): pass - + def dump_state(self): state_keys = ["k"] return {k: getattr(self, k) for k in state_keys} @@ -24,20 +26,18 @@ def dump_state(self): def load_state(self, state): for name, value in state.items(): setattr(self, name, value) - + def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - - def forward(self, query_or_queries): + + def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> Prediction: queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [query.strip().split('\n')[0].strip() for query in queries] - # print(queries) # TODO: Consider removing any quote-like markers that surround the query too. - - passages = dsp.retrieveEnsemble(queries, k=self.k) + k = k if k is not None else self.k + passages = dsp.retrieveEnsemble(queries, k=k) return Prediction(passages=passages) - -# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. \ No newline at end of file +# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. 
diff --git a/dspy/retrieve/you_rm.py b/dspy/retrieve/you_rm.py index 4d5273cc4..a5f69d9f9 100644 --- a/dspy/retrieve/you_rm.py +++ b/dspy/retrieve/you_rm.py @@ -2,7 +2,7 @@ import os import requests -from typing import Union, List +from typing import Union, List, Optional class YouRM(dspy.Retrieve): @@ -15,15 +15,19 @@ def __init__(self, ydc_api_key=None, k=3): else: self.ydc_api_key = os.environ["YDC_API_KEY"] - def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction: + def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction: """Search with You.com for self.k top passages for query or queries Args: query_or_queries (Union[str, List[str]]): The query or queries to search for. + k (Optional[int]): The number of top hits to take snippets from. Overrides self.k when given; defaults to self.k when None Returns: dspy.Prediction: An object containing the retrieved passages. """ + + k = k if k is not None else self.k + queries = ( [query_or_queries] if isinstance(query_or_queries, str) @@ -36,7 +40,7 @@ def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction: f"https://api.ydc-index.io/search?query={query}", headers=headers, ).json() - for hit in results["hits"][:self.k]: + for hit in results["hits"][:k]: for snippet in hit["snippets"]: docs.append(snippet) return dspy.Prediction(passages=docs)