diff --git a/dsp/modules/google.py b/dsp/modules/google.py new file mode 100644 index 000000000..581c98414 --- /dev/null +++ b/dsp/modules/google.py @@ -0,0 +1,104 @@ +import math +from typing import Any, Optional +import backoff + +from dsp.modules.lm import LM + +try: + import google.generativeai as genai +except ImportError: + google_api_error = Exception + print("Not loading Google because it is not installed.") + +def backoff_hdlr(details): + """Handler from https://pypi.org/project/backoff/""" + print( + "Backing off {wait:0.1f} seconds after {tries} tries " + "calling function {target} with kwargs " + "{kwargs}".format(**details) + ) + + +def giveup_hdlr(details): + """wrapper function that decides when to give up on retry""" + if "rate limits" in details.message: + return False + return True + + +class Google(LM): + """Wrapper around Google's API. + + Currently supported models include `gemini-pro-1.0`. + """ + + def __init__( + self, + model: str = "gemini-pro-1.0", + api_key: Optional[str] = None, + **kwargs + ): + """ + Parameters + ---------- + model : str + Which pre-trained model from Google to use? + Choices are [`gemini-pro-1.0`] + api_key : str + The API key for Google. + It can be obtained from https://cloud.google.com/generative-ai-studio + **kwargs: dict + Additional arguments to pass to the API provider. + """ + super().__init__(model) + self.google = genai.configure(api_key=self.api_key) + self.provider = "google" + self.kwargs = { + "model_name": model, + "temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"], + "max_output_tokens": 2048, + "top_p": 1, + "top_k": 1, + **kwargs + } + + self.history: list[dict[str, Any]] = [] + + def basic_request(self, prompt: str, **kwargs): + raw_kwargs = kwargs + kwargs = { + **self.kwargs, + "prompt": prompt, + **kwargs, + } + response = self.co.generate(**kwargs) + + history = { + "prompt": prompt, + "response": response, + "kwargs": kwargs, + "raw_kwargs": raw_kwargs, + } + self.history.append(history) + + return response + + @backoff.on_exception( + backoff.expo, + (google_api_error), + max_time=1000, + on_backoff=backoff_hdlr, + giveup=giveup_hdlr, + ) + def request(self, prompt: str, **kwargs): + """Handles retrieval of completions from Google whilst handling API errors""" + return self.basic_request(prompt, **kwargs) + + def __call__( + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs + ): + return self.request(prompt, **kwargs)