community[minor]: add Jina Reranker in retrievers module (#19406)
- **Description:** Adapt JinaEmbeddings to run with the new Jina AI Rerank API
- **Twitter handle:** https://twitter.com/JinaAI_
- [ ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, and 2. an example notebook showing its use, which lives in the `docs/docs/integrations` directory.
- [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See the contribution guidelines for more: https://python.langchain.com/docs/contributing/

Co-authored-by: Bagatur <[email protected]>
1 parent 92969d4 · commit baefbfb
Showing 4 changed files with 384 additions and 1 deletion.
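Before the diffs, a quick orientation: the new `JinaRerank` compressor is a thin wrapper around Jina's hosted rerank endpoint. The sketch below shows the shape of the HTTP call it issues, with field names taken from the module added in this commit; the API key and document texts are placeholders, not part of the change.

```python
import os

import requests

# Placeholder setup: export JINA_API_KEY before running; documents are dummy texts.
resp = requests.post(
    "https://api.jina.ai/v1/rerank",
    headers={"Authorization": f"Bearer {os.environ['JINA_API_KEY']}"},
    json={
        "model": "jina-reranker-v1-base-en",
        "query": "What did the president say about Ketanji Brown Jackson",
        "documents": ["doc one text", "doc two text"],
        "top_n": 2,
    },
).json()

# The module added below reads `results`, each carrying an `index` into the
# submitted documents and a `relevance_score`.
for result in resp["results"]:
    print(result["index"], result["relevance_score"])
```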
254 changes: 254 additions & 0 deletions
docs/docs/integrations/document_transformers/jina_rerank.ipynb
@@ -0,0 +1,254 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f6ff09ab-c736-4a18-a717-563b4e29d22d",
   "metadata": {},
   "source": [
    "# Jina Reranker"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1288789a-4c30-4fc3-90c7-dd1741a2550b",
   "metadata": {},
   "source": [
    "This notebook shows how to use Jina Reranker for document compression and retrieval."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0e4d52e-3968-4f8b-9865-a886f27e5feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -qU langchain langchain-openai langchain-community langchain-text-splitters langchainhub\n",
    "\n",
    "%pip install --upgrade --quiet faiss\n",
    "\n",
    "# OR (depending on Python version)\n",
    "\n",
    "%pip install --upgrade --quiet faiss_cpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1fc07a6-8e01-4aa5-8ed4-ca2b0bfca70c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function for printing docs\n",
    "\n",
    "\n",
    "def pretty_print_docs(docs):\n",
    "    print(\n",
    "        f\"\\n{'-' * 100}\\n\".join(\n",
    "            [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8ec4823-fdc1-4339-8a25-da598a1e2a4c",
   "metadata": {},
   "source": [
    "## Set up the base vector store retriever"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9db25269-e798-496f-8fb9-2bb280735118",
   "metadata": {},
   "source": [
    "Let's start by initializing a simple vector store retriever and storing the 2023 State of the Union speech (in chunks). We can set up the retriever to retrieve a high number (20) of docs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce01a2b5-d7f4-4902-9156-9a3a86704f40",
   "metadata": {},
   "source": [
    "##### Set the Jina and OpenAI API keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6692d5c5-c84a-4d42-8dd8-5ce90ff56d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
    "os.environ[\"JINA_API_KEY\"] = getpass.getpass()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981159af-fa3c-4f75-adb4-1a4de1950f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import TextLoader\n",
    "from langchain_community.embeddings import JinaEmbeddings\n",
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "\n",
    "documents = TextLoader(\n",
    "    \"../../modules/state_of_the_union.txt\",\n",
    ").load()\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
    "texts = text_splitter.split_documents(documents)\n",
    "\n",
    "embedding = JinaEmbeddings(model_name=\"jina-embeddings-v2-base-en\")\n",
    "retriever = FAISS.from_documents(texts, embedding).as_retriever(search_kwargs={\"k\": 20})\n",
    "\n",
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "docs = retriever.get_relevant_documents(query)\n",
    "pretty_print_docs(docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5a514b7-027a-4dd4-9cfc-63fb4d50aa66",
   "metadata": {},
   "source": [
    "## Doing reranking with JinaRerank"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdd9e0ca-d728-42cb-88ad-459fb8a56b33",
   "metadata": {},
   "source": [
    "Now let's wrap our base retriever with a ContextualCompressionRetriever, using Jina Reranker as a compressor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3000019e-cc0d-4365-91d0-72247ee4d624",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers import ContextualCompressionRetriever\n",
    "from langchain_community.document_compressors import JinaRerank\n",
    "\n",
    "compressor = JinaRerank()\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=retriever\n",
    ")\n",
    "\n",
    "compressed_docs = compression_retriever.get_relevant_documents(\n",
    "    \"What did the president say about Ketanji Brown Jackson\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f314f74c-48a9-4243-8d3c-2b7f820e1e40",
   "metadata": {},
   "outputs": [],
   "source": [
    "pretty_print_docs(compressed_docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87164f04-194b-4138-8d94-f179f6f34a31",
   "metadata": {},
   "source": [
    "## QA reranking with Jina Reranker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2b4ab60b-5a26-4cfb-9b58-3dc2d83b772b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m System Message \u001b[0m================================\n",
      "\n",
      "Answer any use questions based solely on the context below:\n",
      "\n",
      "<context>\n",
      "\u001b[33;1m\u001b[1;3m{context}\u001b[0m\n",
      "</context>\n",
      "\n",
      "=============================\u001b[1m Messages Placeholder \u001b[0m=============================\n",
      "\n",
      "\u001b[33;1m\u001b[1;3m{chat_history}\u001b[0m\n",
      "\n",
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "\u001b[33;1m\u001b[1;3m{input}\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "from langchain import hub\n",
    "from langchain.chains import create_retrieval_chain\n",
    "from langchain.chains.combine_documents import create_stuff_documents_chain\n",
    "\n",
    "retrieval_qa_chat_prompt = hub.pull(\"langchain-ai/retrieval-qa-chat\")\n",
    "retrieval_qa_chat_prompt.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72af3eb3-b644-4b5f-bf5f-f1dc43c96882",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n",
    "combine_docs_chain = create_stuff_documents_chain(llm, retrieval_qa_chat_prompt)\n",
    "chain = create_retrieval_chain(compression_retriever, combine_docs_chain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "126401a7-c545-4de0-92dc-e9bc1001a6ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain.invoke({\"input\": query})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "poetry-venv-2",
   "language": "python",
   "name": "poetry-venv-2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
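The notebook above constructs `JinaRerank()` with its defaults. The compressor also accepts a few constructor options, each backed by a field declared in the `jina_rerank.py` module shown next; this is a minimal sketch and the key value is a placeholder.

```python
from langchain_community.document_compressors import JinaRerank

compressor = JinaRerank(
    top_n=5,  # number of reranked documents to keep (the field defaults to 3)
    model="jina-reranker-v1-base-en",  # reranker model name (the module's default)
    jina_api_key="...",  # placeholder; otherwise read from the JINA_API_KEY env var
)
```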
125 changes: 125 additions & 0 deletions
libs/community/langchain_community/document_compressors/jina_rerank.py
@@ -0,0 +1,125 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

import requests
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env

JINA_API_URL: str = "https://api.jina.ai/v1/rerank"


class JinaRerank(BaseDocumentCompressor):
    """Document compressor that uses the `Jina Rerank API`."""

    session: Any = None
    """Requests session to communicate with the API."""
    top_n: Optional[int] = 3
    """Number of documents to return."""
    model: str = "jina-reranker-v1-base-en"
    """Model to use for reranking."""
    jina_api_key: Optional[str] = None
    """Jina API key. Must be specified directly or via the environment variable
    JINA_API_KEY."""
    user_agent: str = "langchain"
    """Identifier for the application making the request."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the environment."""
        jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
        user_agent = values.get("user_agent", "langchain")
        session = requests.Session()
        session.headers.update(
            {
                "Authorization": f"Bearer {jina_api_key}",
                "Accept-Encoding": "identity",
                "Content-type": "application/json",
                "user-agent": user_agent,
            }
        )
        values["session"] = session
        return values

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        *,
        model: Optional[str] = None,
        top_n: Optional[int] = -1,
        max_chunks_per_doc: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Return scoring results ordered by relevance to the provided query.

        Each result is a dict with the document ``index`` and its
        ``relevance_score``.

        Args:
            query: The query to use for reranking.
            documents: A sequence of documents to rerank.
            model: The model to use for reranking. Defaults to self.model.
            top_n: The number of results to return. If None, returns all results.
                Defaults to self.top_n.
            max_chunks_per_doc: The maximum number of chunks derived from a document.
        """
        if len(documents) == 0:  # to avoid an empty API call
            return []
        docs = [
            doc.page_content if isinstance(doc, Document) else doc for doc in documents
        ]
        model = model or self.model
        # A non-positive top_n (the -1 sentinel) falls back to the configured default.
        top_n = top_n if (top_n is None or top_n > 0) else self.top_n
        data = {
            "query": query,
            "documents": docs,
            "model": model,
            "top_n": top_n,
        }

        resp = self.session.post(
            JINA_API_URL,
            json=data,
        ).json()

        if "results" not in resp:
            raise RuntimeError(resp["detail"])

        results = resp["results"]
        result_dicts = []
        for res in results:
            result_dicts.append(
                {"index": res["index"], "relevance_score": res["relevance_score"]}
            )
        return result_dicts

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Compress documents using Jina's Rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        compressed = []
        for res in self.rerank(documents, query):
            doc = documents[res["index"]]
            # Copy so the original document's metadata is not mutated.
            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
            doc_copy.metadata["relevance_score"] = res["relevance_score"]
            compressed.append(doc_copy)
        return compressed
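As a final illustration, here is a minimal sketch of exercising the new compressor directly rather than through a retriever, assuming JINA_API_KEY is exported and using placeholder documents.

```python
from langchain_community.document_compressors import JinaRerank
from langchain_core.documents import Document

# Placeholder documents; JINA_API_KEY is assumed to be set in the environment.
docs = [
    Document(page_content="The president nominated Ketanji Brown Jackson."),
    Document(page_content="Inflation and supply chains were also discussed."),
]
reranker = JinaRerank(top_n=1)

# rerank() returns dicts with the original document index and a relevance score.
scores = reranker.rerank(docs, "Who was nominated to the Supreme Court?")
print(scores)  # e.g. [{"index": 0, "relevance_score": ...}]

# compress_documents() returns the top_n documents with the score added to metadata.
top_docs = reranker.compress_documents(docs, "Who was nominated to the Supreme Court?")
print(top_docs[0].metadata["relevance_score"])
```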