diff --git a/gptcache/adapter/adapter.py b/gptcache/adapter/adapter.py
index e167ec24..1d85aadc 100644
--- a/gptcache/adapter/adapter.py
+++ b/gptcache/adapter/adapter.py
@@ -27,6 +27,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
     chat_cache = kwargs.pop("cache_obj", cache)
     session = kwargs.pop("session", None)
     require_object_store = kwargs.pop("require_object_store", False)
+    # NOTE: "metadata" is intentionally left in kwargs so it reaches the data manager's search/save calls
     if require_object_store:
         assert chat_cache.data_manager.o, "Object store is required for adapter."
     if not chat_cache.has_init:
@@ -91,6 +92,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
                 top_k=kwargs.pop("top_k", 5)
                 if (user_temperature and not user_top_k)
                 else kwargs.pop("top_k", -1),
+                **kwargs,
             )
         if search_data_list is None:
             search_data_list = []
@@ -245,7 +247,7 @@
     if cache_enable:
         try:

-            def update_cache_func(handled_llm_data, question=None):
+            def update_cache_func(handled_llm_data, question=None, **kwargs):
                 if question is None:
                     question = pre_store_data
                 else:
@@ -260,6 +262,7 @@ def update_cache_func(handled_llm_data, question=None):
                     embedding_data,
                     extra_param=context.get("save_func", None),
                     session=session,
+                    **kwargs
                 )
                 if (
                     chat_cache.report.op_save.count > 0
@@ -267,7 +270,6 @@ def update_cache_func(handled_llm_data, question=None):
                     == 0
                 ):
                     chat_cache.flush()
-
         llm_data = update_cache_callback(
             llm_data, update_cache_func, *args, **kwargs
         )
@@ -359,6 +361,7 @@ async def aadapt(
                 top_k=kwargs.pop("top_k", 5)
                 if (user_temperature and not user_top_k)
                 else kwargs.pop("top_k", -1),
+                **kwargs
             )
         if search_data_list is None:
             search_data_list = []
diff --git a/gptcache/adapter/langchain_models.py b/gptcache/adapter/langchain_models.py
index 945d9eb2..7e660139 100644
--- a/gptcache/adapter/langchain_models.py
+++ b/gptcache/adapter/langchain_models.py
@@ -205,8 +205,8 @@ def generate(
         callbacks: Callbacks = None,
         **kwargs,
     ) -> LLMResult:
         self.tmp_args = kwargs
-        return super().generate(messages, stop=stop, callbacks=callbacks)
+        return super().generate(messages, stop=stop, callbacks=callbacks, **kwargs)

     async def agenerate(
         self,
diff --git a/gptcache/adapter/openai.py b/gptcache/adapter/openai.py
index 6f3d50aa..039d42f0 100644
--- a/gptcache/adapter/openai.py
+++ b/gptcache/adapter/openai.py
@@ -56,6 +56,7 @@ class ChatCompletion(openai.ChatCompletion, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
+            _ = llm_kwargs.pop("metadata", {})  # cache-only metadata must not be sent to the OpenAI API
             return super().create(*llm_args, **llm_kwargs) if cls.llm is None else cls.llm(*llm_args, **llm_kwargs)
         except openai.OpenAIError as e:
             raise wrap_error(e) from e
@@ -66,7 +67,7 @@ def _update_cache_callback(
     ):  # pylint: disable=unused-argument
         if not isinstance(llm_data, Iterator):
             update_cache_func(
-                Answer(get_message_from_openai_answer(llm_data), DataType.STR)
+                Answer(get_message_from_openai_answer(llm_data), DataType.STR), **kwargs
             )
             return llm_data
         else:
@@ -76,7 +77,7 @@ def hook_openai_data(it):
                 total_answer = ""
                 for item in it:
                     total_answer += 
get_stream_message_from_openai_answer(item)
                     yield item
-                update_cache_func(Answer(total_answer, DataType.STR))
+                update_cache_func(Answer(total_answer, DataType.STR), **kwargs)

             return hook_openai_data(llm_data)
diff --git a/gptcache/core.py b/gptcache/core.py
index 313801e9..b8d88e23 100644
--- a/gptcache/core.py
+++ b/gptcache/core.py
@@ -87,7 +87,7 @@ def close():
             if not os.getenv("IS_CI"):
                 gptcache_log.error(e)

-    def import_data(self, questions: List[Any], answers: List[Any], session_ids: Optional[List[Optional[str]]] = None) -> None:
+    def import_data(self, questions: List[Any], answers: List[Any], session_ids: Optional[List[Optional[str]]] = None, **kwargs) -> None:
        """Import data to GPTCache

        :param questions: preprocessed question Data
@@ -101,6 +101,7 @@ def import_data(self, questions: List[Any], answers: List[Any], session_ids: Opt
            answers=answers,
            embedding_datas=[self.embedding_func(question) for question in questions],
            session_ids=session_ids if session_ids else [None for _ in range(len(questions))],
+           **kwargs
        )

    def flush(self):
diff --git a/gptcache/manager/data_manager.py b/gptcache/manager/data_manager.py
index a30ce010..f24fae98 100644
--- a/gptcache/manager/data_manager.py
+++ b/gptcache/manager/data_manager.py
@@ -35,6 +35,7 @@ def import_data(
         answers: List[Any],
         embedding_datas: List[Any],
         session_ids: List[Optional[str]],
+        **kwargs
     ):
         pass

@@ -135,6 +136,7 @@ def import_data(
         answers: List[Any],
         embedding_datas: List[Any],
         session_ids: List[Optional[str]],
+        **kwargs
     ):
         if (
             len(questions) != len(answers)
@@ -271,7 +273,7 @@ def save(self, question, answer, embedding_data, **kwargs):
         """
         session = kwargs.get("session", None)
         session_id = session.name if session else None
-        self.import_data([question], [answer], [embedding_data], [session_id])
+        self.import_data([question], [answer], [embedding_data], [session_id], **kwargs)

     def _process_answer_data(self, answers: Union[Answer, List[Answer]]):
         if isinstance(answers, Answer):
@@ -302,6 +304,7 @@ def import_data(
         answers: List[Answer],
         embedding_datas: List[Any],
         session_ids: List[Optional[str]],
+        **kwargs,
     ):
         if (
             len(questions) != len(answers)
@@ -332,7 +335,7 @@ def import_data(
             [
                 VectorData(id=ids[i], data=embedding_data)
                 for i, embedding_data in enumerate(embedding_datas)
-            ]
+            ], kwargs=kwargs
         )
         self.eviction_base.put(ids)

@@ -367,8 +370,8 @@ def hit_cache_callback(self, res_data, **kwargs):

     def search(self, embedding_data, **kwargs):
         embedding_data = normalize(embedding_data)
-        top_k = kwargs.get("top_k", -1)
-        return self.v.search(data=embedding_data, top_k=top_k)
+        top_k = kwargs.pop("top_k", -1)
+        return self.v.search(data=embedding_data, top_k=top_k, **kwargs)

     def flush(self):
         self.s.flush()
diff --git a/gptcache/manager/vector_data/base.py b/gptcache/manager/vector_data/base.py
index 17b4825a..62173056 100644
--- a/gptcache/manager/vector_data/base.py
+++ b/gptcache/manager/vector_data/base.py
@@ -9,17 +9,19 @@ class VectorData:
     id: int
     data: np.ndarray
+    account_id: int = -1
+    pipeline: str = ''


 class VectorBase(ABC):
     """VectorBase: base vector store interface"""

     @abstractmethod
-    def mul_add(self, datas: List[VectorData]):
+    def mul_add(self, datas: List[VectorData], **kwargs):
         pass

     @abstractmethod
-    def search(self, data: np.ndarray, top_k: int):
+    def search(self, data: np.ndarray, top_k: int, **kwargs):
         pass

     @abstractmethod
diff --git 
a/gptcache/manager/vector_data/chroma.py b/gptcache/manager/vector_data/chroma.py
index 2d9eb161..8454d9f7 100644
--- a/gptcache/manager/vector_data/chroma.py
+++ b/gptcache/manager/vector_data/chroma.py
@@ -46,11 +46,11 @@ def __init__(
         self._persist_directory = persist_directory
         self._collection = self._client.get_or_create_collection(name=collection_name)

-    def mul_add(self, datas: List[VectorData]):
+    def mul_add(self, datas: List[VectorData], **kwargs):
         data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
         self._collection.add(embeddings=data_array, ids=id_array)

-    def search(self, data, top_k: int = -1):
+    def search(self, data, top_k: int = -1, **kwargs):
         if self._collection.count() == 0:
             return []
         if top_k == -1:
diff --git a/gptcache/manager/vector_data/faiss.py b/gptcache/manager/vector_data/faiss.py
index 65643424..8a65f4e7 100644
--- a/gptcache/manager/vector_data/faiss.py
+++ b/gptcache/manager/vector_data/faiss.py
@@ -30,13 +30,13 @@ def __init__(self, index_file_path, dimension, top_k):
         if os.path.isfile(index_file_path):
             self._index = faiss.read_index(index_file_path)

-    def mul_add(self, datas: List[VectorData]):
+    def mul_add(self, datas: List[VectorData], **kwargs):
         data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
         np_data = np.array(data_array).astype("float32")
         ids = np.array(id_array)
         self._index.add_with_ids(np_data, ids)

-    def search(self, data: np.ndarray, top_k: int = -1):
+    def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
         if self._index.ntotal == 0:
             return None
         if top_k == -1:
diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py
index 86616b5b..7d82a85b 100644
--- a/gptcache/manager/vector_data/manager.py
+++ b/gptcache/manager/vector_data/manager.py
@@ -257,6 +257,19 @@ def get(name, **kwargs):
                 flush_interval_sec=flush_interval_sec,
                 index_params=index_params,
             )
+        elif name == "pinecone":
+            # imported here so pinecone is only required when this store is selected
+            import pinecone
+            from gptcache.manager.vector_data.pinecone import Pinecone
+            api_key = kwargs.get("api_key", None)
+            environment = kwargs.get("environment", None)
+            # SearchDistanceEvaluation expects distance-style scores, so default to euclidean
+            metric = kwargs.get("metric", "euclidean")
+            dimension = kwargs.get("dimension", DIMENSION)
+            top_k = kwargs.get("top_k", TOP_K)
+            index_name = kwargs.get("index_name", "caching")
+            pinecone.init(api_key=api_key, environment=environment)
+            vector_base = Pinecone(index_file_path=index_name, dimension=dimension, top_k=top_k, metric=metric)
         else:
             raise NotFoundError("vector store", name)
         return vector_base
diff --git a/gptcache/manager/vector_data/pinecone.py b/gptcache/manager/vector_data/pinecone.py
new file mode 100644
index 00000000..acb72ac1
--- /dev/null
+++ b/gptcache/manager/vector_data/pinecone.py
@@ -0,0 +1,76 @@
+from typing import List
+import time
+
+import numpy as np
+import pinecone
+
+from gptcache.manager.vector_data.base import VectorBase, VectorData
+
+
+class Pinecone(VectorBase):
+    """vector store: Pinecone
+
+    :param index_file_path: the name of the Pinecone index, defaults to 'caching'.
+    :type index_file_path: str
+    :param dimension: the dimension of the vector, defaults to 0.
+    :type dimension: int
+    :param top_k: the number of the vectors results to return, defaults to 1.
+    :type top_k: int
+    :param metric: the distance metric of the index; only 'euclidean' is supported.
+    :type metric: str
+    """
+
+    def __init__(self, index_file_path, dimension, top_k, metric):
+        self._index_file_path = index_file_path
+        self._dimension = dimension
+        assert metric == "euclidean", "the cache's distance evaluation expects euclidean scores"
+        self.indexes = pinecone.list_indexes()
+        if index_file_path not in self.indexes:
+            pinecone.create_index(index_file_path, dimension=dimension, metric=metric)
+            # index creation is asynchronous; give Pinecone time to make it available
+            time.sleep(50)
+        self.index = pinecone.Index(index_file_path)
+        self._top_k = top_k
+
+    def mul_add(self, datas: List[VectorData], **kwargs):
+        # SSDataManager.import_data forwards its kwargs as a single "kwargs" keyword,
+        # so the caller's metadata sits one level down
+        metadata = kwargs.get("kwargs", {}).get("metadata", {})
+        assert metadata != {}, "Please provide metadata for the request being cached!"
+        data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
+        np_data = np.array(data_array).astype("float32")
+        ids = np.array(id_array)
+        upsert_data = [
+            (str(i_d), data.tolist(), {"account_id": int(metadata["account_id"]), "pipeline": str(metadata["pipeline"])})
+            for (i_d, data) in zip(ids, np_data)
+        ]
+        self.index.upsert(upsert_data)
+
+    def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
+        if self.index.describe_index_stats()["total_vector_count"] == 0:
+            return None
+        if top_k == -1:
+            top_k = self._top_k
+        metadata = kwargs.get("metadata", {})
+        assert metadata != {}, "Please provide metadata for the search query!"
+        np_data = np.array(data).astype("float32").reshape(1, -1)
+        # only vectors whose account_id and pipeline match the caller's metadata are considered
+        response = self.index.query(
+            vector=np_data[0].tolist(),
+            top_k=top_k,
+            include_values=False,
+            filter={"account_id": int(metadata["account_id"]), "pipeline": str(metadata["pipeline"])},
+        )
+        if len(response["matches"]) != 0:
+            return [(match["score"], int(match["id"])) for match in response["matches"]]
+        return None
+
+    def rebuild(self, ids=None):
+        return True
+
+    def delete(self, ids):
+        # vectors are stored under string ids, so convert before deleting
+        self.index.delete(ids=[str(i) for i in ids])  # TODO: support namespaces
+
+    def count(self):
+        return self.index.describe_index_stats()["total_vector_count"]
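
For quick, isolated testing of the new store outside the full cache pipeline, something like the following should work. This is a minimal sketch, not part of the patch: it assumes a valid Pinecone API key and environment (placeholders below), uses the `VectorBase` factory from `gptcache.manager`, and mimics the `kwargs=kwargs` wrapping that `SSDataManager.import_data` applies when it calls `mul_add`.

```python
import time

import numpy as np

from gptcache.manager import VectorBase
from gptcache.manager.vector_data.base import VectorData

# placeholders: substitute real credentials before running
store = VectorBase(
    "pinecone",
    api_key="YOUR_PINECONE_API_KEY",
    environment="YOUR_PINECONE_ENVIRONMENT",
    dimension=8,
    top_k=1,
    index_name="caching-smoke-test",
    metric="euclidean",
)

meta = {"account_id": "-123", "pipeline": "completion"}
vec = np.random.rand(8).astype("float32")

# import_data wraps its kwargs under a single "kwargs" keyword, so mimic that here
store.mul_add([VectorData(id=1, data=vec)], kwargs={"metadata": meta})
time.sleep(5)  # Pinecone indexing is eventually consistent; give the upsert a moment

print(store.search(vec, top_k=1, metadata=meta))  # expected: [(distance, 1)]
print(store.count())
store.delete([1])
```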
diff --git a/testing-pinecone-chat.py b/testing-pinecone-chat.py
new file mode 100644
index 00000000..35379f11
--- /dev/null
+++ b/testing-pinecone-chat.py
@@ -0,0 +1,55 @@
+import time
+from gptcache import cache
+from gptcache.embedding import Onnx
+from gptcache.manager import CacheBase, VectorBase, get_data_manager
+from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
+from gptcache.adapter.langchain_models import LangChainChat
+from langchain.chat_models import ChatOpenAI
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage
+)
+
+print("Cache loading.....")
+
+def get_msg(data, **_):
+    return data.get("messages")[-1].content
+
+onnx = Onnx()
+### uncomment one of the following lines depending on which vector store you want to use
+# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("pinecone", \
+#     dimension=onnx.dimension, index_name='caching', api_key='YOUR_PINECONE_API_KEY',\
+#     environment = 'asia-southeast1-gcp-free', metric='euclidean'))
+
+data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=onnx.dimension))
+# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("chromadb", dimension=onnx.dimension))
+# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("docarray", dimension=onnx.dimension))
+
+cache.init(
+    pre_embedding_func=get_msg,
+    embedding_func=onnx.to_embeddings,
+    data_manager=data_manager,
+    similarity_evaluation=SearchDistanceEvaluation(),
+    )
+cache.set_openai_key()
+
+chat = LangChainChat(chat=ChatOpenAI(temperature=0.0))
+
+questions = [
+    "tell me something about chatgpt",
+    "what is chatgpt?",
+]
+
+metadata = {
+    'account_id': '-123',
+    'pipeline': 'completion'
+}
+
+if __name__ == "__main__":
+    for question in questions:
+        start_time = time.time()
+        messages = [HumanMessage(content=question)]
+        print(chat(messages, metadata=metadata))
+        print(f'Question: {question}')
+        print("Time consuming: {:.2f}s".format(time.time() - start_time))
\ No newline at end of file
diff --git a/testing-pinecone.py b/testing-pinecone.py
new file mode 100644
index 00000000..9c6c6151
--- /dev/null
+++ b/testing-pinecone.py
@@ -0,0 +1,57 @@
+import time
+from gptcache import cache
+from gptcache.adapter import openai
+from gptcache.embedding import Onnx
+from gptcache.manager import CacheBase, VectorBase, get_data_manager
+from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
+
+def response_text(openai_resp):
+    return openai_resp['choices'][0]['message']['content']
+
+print("Cache loading.....")
+
+onnx = Onnx()
+
+### uncomment one of the following lines depending on which vector store you want to use
+data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("pinecone", \
+    dimension=onnx.dimension, index_name='caching', api_key='YOUR_PINECONE_API_KEY',\
+    environment = 'asia-southeast1-gcp-free', metric='euclidean'))
+
+# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=onnx.dimension))
+# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("chromadb", dimension=onnx.dimension))
+# data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("docarray", dimension=onnx.dimension))
+
+cache.init(
+    embedding_func=onnx.to_embeddings,
+    data_manager=data_manager,
+    similarity_evaluation=SearchDistanceEvaluation(),
+    )
+cache.set_openai_key()
+
+questions = [
+    "tell me something about chatgpt",
+    "what is chatgpt?",
+]
+
+metadata = {
+    'account_id': '-123',
+    'pipeline': 'completion'
+}
+
+if __name__ == "__main__":
+    for question in questions:
+        start_time = time.time()
+        response = openai.ChatCompletion.create(
+            model='gpt-3.5-turbo',
+            messages=[
+                {
+                    'role': 'user',
+                    'content': question
+                }
+            ],
+            metadata=metadata
+        )
+        print(f'Question: {question}')
+        print("Time consuming: {:.2f}s".format(time.time() - start_time))
+        print(f'Answer: {response_text(response)}\n')
+        # print("usage: ",response['usage'])
\ No newline at end of file
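
End to end, the metadata is what scopes cache hits per tenant. A quick way to see that is to repeat the same question under two different `account_id` values: the second account should miss the cache and go back to the model. The following is a minimal sketch built on the same configuration as the test scripts above; the Pinecone API key and environment are placeholders and `OPENAI_API_KEY` is assumed to be set.

```python
import time

from gptcache import cache
from gptcache.adapter import openai
from gptcache.embedding import Onnx
from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

onnx = Onnx()
data_manager = get_data_manager(
    CacheBase("sqlite"),
    VectorBase(
        "pinecone",
        dimension=onnx.dimension,
        index_name="caching",
        api_key="YOUR_PINECONE_API_KEY",
        environment="YOUR_PINECONE_ENVIRONMENT",
        metric="euclidean",
    ),
)
cache.init(
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
)
cache.set_openai_key()


def ask(question, account_id):
    start = time.time()
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": question}],
        metadata={"account_id": account_id, "pipeline": "completion"},
    )
    print(f"account {account_id}: {time.time() - start:.2f}s")
    return response


ask("what is chatgpt?", "-123")  # goes to the model and fills the cache
ask("what is chatgpt?", "-123")  # same account: should be answered from the cache
ask("what is chatgpt?", "-456")  # different account: filtered out, hits the model again
```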