def pretty_print_docs(docs):
    """Print each document's content, numbered and separated by a dashed rule."""
    separator = "\n" + "-" * 100 + "\n"
    rendered = [
        f"Document {index}:\n\n{doc.page_content}"
        for index, doc in enumerate(docs, start=1)
    ]
    print(separator.join(rendered))
+ ] + }, + { + "cell_type": "markdown", + "id": "ce01a2b5-d7f4-4902-9156-9a3a86704f40", + "metadata": {}, + "source": [ + "##### Set the Jina and OpenAI API keys" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6692d5c5-c84a-4d42-8dd8-5ce90ff56d20", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", + "os.environ[\"JINA_API_KEY\"] = getpass.getpass()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "981159af-fa3c-4f75-adb4-1a4de1950f2f", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.document_loaders import TextLoader\n", + "from langchain_community.embeddings import JinaEmbeddings\n", + "from langchain_community.vectorstores import FAISS\n", + "from langchain_text_splitters import RecursiveCharacterTextSplitter\n", + "\n", + "documents = TextLoader(\n", + " \"../../modules/state_of_the_union.txt\",\n", + ").load()\n", + "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n", + "texts = text_splitter.split_documents(documents)\n", + "\n", + "embedding = JinaEmbeddings(model_name=\"jina-embeddings-v2-base-en\")\n", + "retriever = FAISS.from_documents(texts, embedding).as_retriever(search_kwargs={\"k\": 20})\n", + "\n", + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs = retriever.get_relevant_documents(query)\n", + "pretty_print_docs(docs)" + ] + }, + { + "cell_type": "markdown", + "id": "b5a514b7-027a-4dd4-9cfc-63fb4d50aa66", + "metadata": {}, + "source": [ + "## Doing reranking with JinaRerank" + ] + }, + { + "cell_type": "markdown", + "id": "bdd9e0ca-d728-42cb-88ad-459fb8a56b33", + "metadata": {}, + "source": [ + "Now let's wrap our base retriever with a ContextualCompressionRetriever, using Jina Reranker as a compressor." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3000019e-cc0d-4365-91d0-72247ee4d624", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import ContextualCompressionRetriever\n", + "from langchain_community.document_compressors import JinaRerank\n", + "\n", + "compressor = JinaRerank()\n", + "compression_retriever = ContextualCompressionRetriever(\n", + " base_compressor=compressor, base_retriever=retriever\n", + ")\n", + "\n", + "compressed_docs = compression_retriever.get_relevant_documents(\n", + " \"What did the president say about Ketanji Jackson Brown\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f314f74c-48a9-4243-8d3c-2b7f820e1e40", + "metadata": {}, + "outputs": [], + "source": [ + "pretty_print_docs(compressed_docs)" + ] + }, + { + "cell_type": "markdown", + "id": "87164f04-194b-4138-8d94-f179f6f34a31", + "metadata": {}, + "source": [ + "## QA reranking with Jina Reranker" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2b4ab60b-5a26-4cfb-9b58-3dc2d83b772b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================\u001b[1m System Message \u001b[0m================================\n", + "\n", + "Answer any use questions based solely on the context below:\n", + "\n", + "\n", + "\u001b[33;1m\u001b[1;3m{context}\u001b[0m\n", + "\n", + "\n", + "=============================\u001b[1m Messages Placeholder \u001b[0m=============================\n", + "\n", + "\u001b[33;1m\u001b[1;3m{chat_history}\u001b[0m\n", + "\n", + "================================\u001b[1m Human Message \u001b[0m=================================\n", + "\n", + "\u001b[33;1m\u001b[1;3m{input}\u001b[0m\n" + ] + } + ], + "source": [ + "from langchain import hub\n", + "from langchain.chains import create_retrieval_chain\n", + "from langchain.chains.combine_documents import create_stuff_documents_chain\n", + "\n", + 
"retrieval_qa_chat_prompt = hub.pull(\"langchain-ai/retrieval-qa-chat\")\n", + "retrieval_qa_chat_prompt.pretty_print()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72af3eb3-b644-4b5f-bf5f-f1dc43c96882", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_openai import ChatOpenAI\n", + "\n", + "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n", + "combine_docs_chain = create_stuff_documents_chain(llm, retrieval_qa_chat_prompt)\n", + "chain = create_retrieval_chain(compression_retriever, combine_docs_chain)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "126401a7-c545-4de0-92dc-e9bc1001a6ba", + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke({\"input\": query})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "poetry-venv-2", + "language": "python", + "name": "poetry-venv-2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/community/langchain_community/document_compressors/__init__.py b/libs/community/langchain_community/document_compressors/__init__.py index 4b7cb93e9429a..ca9822d4baea2 100644 --- a/libs/community/langchain_community/document_compressors/__init__.py +++ b/libs/community/langchain_community/document_compressors/__init__.py @@ -2,6 +2,9 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from langchain_community.document_compressors.jina_rerank import ( + JinaRerank, # noqa: F401 + ) from langchain_community.document_compressors.llmlingua_filter import ( LLMLinguaCompressor, # noqa: F401 ) @@ -14,6 +17,7 @@ _module_lookup = { "LLMLinguaCompressor": "langchain_community.document_compressors.llmlingua_filter", "OpenVINOReranker": 
"langchain_community.document_compressors.openvino_rerank", + "JinaRerank": "langchain_community.document_compressors.jina_rerank", } diff --git a/libs/community/langchain_community/document_compressors/jina_rerank.py b/libs/community/langchain_community/document_compressors/jina_rerank.py new file mode 100644 index 0000000000000..0e1f40a822321 --- /dev/null +++ b/libs/community/langchain_community/document_compressors/jina_rerank.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Sequence, Union + +import requests +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.utils import get_from_dict_or_env + +JINA_API_URL: str = "https://api.jina.ai/v1/rerank" + + +class JinaRerank(BaseDocumentCompressor): + """Document compressor that uses `Jina Rerank API`.""" + + session: Any = None + """Requests session to communicate with API.""" + top_n: Optional[int] = 3 + """Number of documents to return.""" + model: str = "jina-reranker-v1-base-en" + """Model to use for reranking.""" + jina_api_key: Optional[str] = None + """Jina API key. 
Must be specified directly or via environment variable + JINA_API_KEY.""" + user_agent: str = "langchain" + """Identifier for the application making the request.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator(pre=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key exists in environment.""" + jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY") + user_agent = values.get("user_agent", "langchain") + session = requests.Session() + session.headers.update( + { + "Authorization": f"Bearer {jina_api_key}", + "Accept-Encoding": "identity", + "Content-type": "application/json", + "user-agent": user_agent, + } + ) + values["session"] = session + return values + + def rerank( + self, + documents: Sequence[Union[str, Document, dict]], + query: str, + *, + model: Optional[str] = None, + top_n: Optional[int] = -1, + max_chunks_per_doc: Optional[int] = None, + ) -> List[Dict[str, Any]]: + """Returns an ordered list of documents ordered by their relevance to the provided query. + + Args: + query: The query to use for reranking. + documents: A sequence of documents to rerank. + model: The model to use for re-ranking. Default to self.model. + top_n : The number of results to return. If None returns all results. + Defaults to self.top_n. + max_chunks_per_doc : The maximum number of chunks derived from a document. 
+ """ # noqa: E501 + if len(documents) == 0: # to avoid empty api call + return [] + docs = [ + doc.page_content if isinstance(doc, Document) else doc for doc in documents + ] + model = model or self.model + top_n = top_n if (top_n is None or top_n > 0) else self.top_n + data = { + "query": query, + "documents": docs, + "model": model, + "top_n": top_n, + } + + resp = self.session.post( + JINA_API_URL, + json=data, + ).json() + + if "results" not in resp: + raise RuntimeError(resp["detail"]) + + results = resp["results"] + result_dicts = [] + for res in results: + result_dicts.append( + {"index": res["index"], "relevance_score": res["relevance_score"]} + ) + return result_dicts + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + """ + Compress documents using Jina's Rerank API. + + Args: + documents: A sequence of documents to compress. + query: The query to use for compressing the documents. + callbacks: Callbacks to run during the compression process. + + Returns: + A sequence of compressed documents. 
+ """ + compressed = [] + for res in self.rerank(documents, query): + doc = documents[res["index"]] + doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata)) + doc_copy.metadata["relevance_score"] = res["relevance_score"] + compressed.append(doc_copy) + return compressed diff --git a/libs/community/tests/unit_tests/document_compressors/test_imports.py b/libs/community/tests/unit_tests/document_compressors/test_imports.py index 2c928857abb24..d6451cdd88ed3 100644 --- a/libs/community/tests/unit_tests/document_compressors/test_imports.py +++ b/libs/community/tests/unit_tests/document_compressors/test_imports.py @@ -1,6 +1,6 @@ from langchain_community.document_compressors import __all__, _module_lookup -EXPECTED_ALL = ["LLMLinguaCompressor", "OpenVINOReranker"] +EXPECTED_ALL = ["LLMLinguaCompressor", "OpenVINOReranker", "JinaRerank"] def test_all_imports() -> None: