Resolve: VectorSearch enabled SQLChain? #7454

Closed
wants to merge 57 commits into from
57 commits
a436cd9
constructor / / implementation
mpskex Apr 17, 2023
92ed371
Merge branch 'hwchase17:master' into feature/vectorstore_myscale
mpskex Apr 17, 2023
8d5da30
update
mpskex Apr 18, 2023
23232aa
update impl
mpskex Apr 19, 2023
0fe57d4
Merge branch 'hwchase17:master' into feature/vectorstore_myscale
mpskex Apr 19, 2023
5fd3bfc
change default
mpskex Apr 19, 2023
2846cb4
Merge branch 'hwchase17:master' into feature/vectorstore_myscale
mpskex Apr 20, 2023
d33bf44
add unit tests
mpskex Apr 20, 2023
fd47d96
Merge branch 'master' into feature/vectorstore_myscale
mpskex Apr 20, 2023
c2298f4
unittest passed
mpskex Apr 20, 2023
9b9e96f
add escape string
mpskex Apr 20, 2023
174de9f
format / lint
mpskex Apr 20, 2023
cadad67
update poetry.lock
mpskex Apr 20, 2023
d015562
revised `__repr__` docstrings
mpskex Apr 21, 2023
9460047
Merge branch 'master' into feature/vectorstore_myscale
mpskex Apr 21, 2023
6f6f27e
Merge pull request #1 from myscale/feature/vectorstore_myscale
mpskex Apr 21, 2023
2baa88b
change `tqdm` into an optional package
mpskex Apr 22, 2023
f6a3a44
fixed typo in myscale docs
mpskex Jun 25, 2023
627a0cb
Merge branch 'hwchase17:master' into master
mpskex Jun 25, 2023
3a763b5
revised prompted operator names
mpskex Jun 25, 2023
1eee8ec
improve sqlalchemy support
mpskex Jul 4, 2023
0bb1e84
Merge branch 'hwchase17:master' into myscale/sql_self_query
mpskex Jul 5, 2023
df4a763
add sql_cmd_parser and examples
mpskex Jul 5, 2023
ea767aa
extensive for other VectorSQL
mpskex Jul 5, 2023
c7c587c
update interface to output parser
mpskex Jul 6, 2023
3a80da4
Merge branch 'myscale/improve_string_pattern_match' into preview
mpskex Jul 6, 2023
98f58e3
Merge branch 'hwchase17:master' into preview
mpskex Jul 6, 2023
c02bc4f
Merge branch 'hwchase17:master' into myscale/sql_self_query
mpskex Jul 6, 2023
b8c1d8e
add sql database retriever
mpskex Jul 7, 2023
3cd8300
revised sql database chain retriever
mpskex Jul 7, 2023
a774bf0
Merge branch 'preview' into myscale/sql_self_query
mpskex Jul 7, 2023
f799f15
lint and format and tests
mpskex Jul 7, 2023
eb02a85
format and fix BaseRow import
mpskex Jul 7, 2023
00e6369
add aget method
mpskex Jul 7, 2023
4e5f90f
Merge branch 'hwchase17:master' into myscale/sql_self_query
mpskex Jul 7, 2023
00b5723
revised docs
mpskex Jul 7, 2023
49f3c85
revised sql self query notebook
mpskex Jul 8, 2023
d9242dc
Merge branch 'hwchase17:master' into myscale/sql_self_query
mpskex Jul 10, 2023
6879ecd
format and lint
mpskex Jul 10, 2023
9b672c3
revised according to comments in #7454
mpskex Jul 13, 2023
90d7d84
Merge branch 'hwchase17:master' into myscale/sql_self_query
mpskex Jul 13, 2023
b00c2dc
lint and fix type
mpskex Jul 13, 2023
416a3fd
revert back comparator prompt
mpskex Jul 13, 2023
d50111e
Merge branch 'hwchase17:master' into myscale/sql_self_query
mpskex Jul 20, 2023
22f17d1
Merge branch 'master' into myscale/sql_self_query
mpskex Jul 24, 2023
f88509c
Merge branch 'master' into myscale/sql_self_query
mpskex Jul 24, 2023
870171f
revert back & relint & reformat
mpskex Jul 24, 2023
3034ef9
add dependency
mpskex Jul 24, 2023
94e2c9d
Merge branch 'master' into myscale/sql_self_query
Jul 31, 2023
3d98141
Merge branch 'langchain-ai:master' into myscale/sql_self_query
mpskex Aug 7, 2023
3801ff6
Merge branch 'langchain-ai:master' into myscale/sql_self_query
mpskex Aug 17, 2023
85dd781
Merge branch 'master' into myscale/sql_self_query
mpskex Aug 23, 2023
f569ca5
lint & add notebook
mpskex Aug 23, 2023
6b6780b
fixed lint
mpskex Aug 24, 2023
878ad52
Merge branch 'master' into myscale/sql_self_query
mpskex Sep 4, 2023
a3a74a3
update poetry.lock
mpskex Sep 4, 2023
82f8314
move it out of sqldatabasechain
mpskex Sep 4, 2023
@@ -0,0 +1,200 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "245065c6",
"metadata": {},
"source": [
"# Vector SQL Retriever with MyScale\n",
"\n",
">[MyScale](https://docs.myscale.com/en/) is an integrated vector database. You can access your database with SQL and also from here, through LangChain. MyScale can make use of [various data types and functions for filters](https://blog.myscale.com/2023/06/06/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints). It can boost your LLM app whether you are scaling up your data or expanding your system to broader applications."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0246c5bf",
"metadata": {},
"outputs": [],
"source": [
"!pip3 install clickhouse-sqlalchemy InstructorEmbedding sentence_transformers openai langchain-experimental"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7585d2c3",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from os import environ\n",
"import getpass\n",
"from typing import Dict, Any\n",
"from langchain import OpenAI, SQLDatabase, LLMChain\n",
"from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n",
"from sqlalchemy import create_engine, Column, MetaData\n",
"from langchain import PromptTemplate\n",
"\n",
"\n",
"MYSCALE_HOST = \"msc-1decbcc9.us-east-1.aws.staging.myscale.cloud\"\n",
"MYSCALE_PORT = 443\n",
"MYSCALE_USER = \"chatdata\"\n",
"MYSCALE_PASSWORD = \"myscale_rocks\"\n",
"OPENAI_API_KEY = getpass.getpass(\"OpenAI API Key:\")\n",
"\n",
"engine = create_engine(\n",
" f\"clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/default?protocol=https\"\n",
")\n",
"metadata = MetaData(bind=engine)\n",
"environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e08d9ddc",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import HuggingFaceInstructEmbeddings\n",
"from langchain_experimental.sql.vector_sql import VectorSQLOutputParser\n",
"\n",
"output_parser = VectorSQLOutputParser.from_embeddings(\n",
" model=HuggingFaceInstructEmbeddings(\n",
" model_name=\"hkunlp/instructor-xl\", model_kwargs={\"device\": \"cpu\"}\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84b705b2",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from langchain.llms import OpenAI\n",
"from langchain.callbacks import StdOutCallbackHandler\n",
"\n",
"from langchain.utilities.sql_database import SQLDatabase\n",
"from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n",
"from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n",
"\n",
"chain = VectorSQLDatabaseChain(\n",
" llm_chain=LLMChain(\n",
" llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
" prompt=MYSCALE_PROMPT,\n",
" ),\n",
" top_k=10,\n",
" return_direct=True,\n",
" sql_cmd_parser=output_parser,\n",
" database=SQLDatabase(engine, None, metadata),\n",
")\n",
"\n",
"import pandas as pd\n",
"\n",
"pd.DataFrame(\n",
" chain.run(\n",
" \"Please give me 10 papers to ask what is PageRank?\",\n",
" callbacks=[StdOutCallbackHandler()],\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6c09cda0",
"metadata": {},
"source": [
"## SQL Database as Retriever"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "734d7ff5",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain\n",
"\n",
"from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain\n",
"from langchain_experimental.retrievers.vector_sql_database \\\n",
" import VectorSQLDatabaseChainRetriever\n",
"from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n",
"from langchain_experimental.sql.vector_sql import VectorSQLRetrieveAllOutputParser\n",
"\n",
"output_parser_retrieve_all = VectorSQLRetrieveAllOutputParser.from_embeddings(\n",
" output_parser.model\n",
")\n",
"\n",
"chain = VectorSQLDatabaseChain.from_llm(\n",
" llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
" prompt=MYSCALE_PROMPT,\n",
" top_k=10,\n",
" return_direct=True,\n",
" db=SQLDatabase(engine, None, metadata),\n",
" sql_cmd_parser=output_parser_retrieve_all,\n",
" native_format=True,\n",
")\n",
"\n",
"# You need all those keys to get docs\n",
"retriever = VectorSQLDatabaseChainRetriever(sql_db_chain=chain, page_content_key=\"abstract\")\n",
"\n",
"document_with_metadata_prompt = PromptTemplate(\n",
" input_variables=[\"page_content\", \"id\", \"title\", \"authors\", \"pubdate\", \"categories\"],\n",
" template=\"Content:\\n\\tTitle: {title}\\n\\tAbstract: {page_content}\\n\\tAuthors: {authors}\\n\\tDate of Publication: {pubdate}\\n\\tCategories: {categories}\\nSOURCE: {id}\",\n",
")\n",
"\n",
"chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
" ChatOpenAI(\n",
" model_name=\"gpt-3.5-turbo-16k\", openai_api_key=OPENAI_API_KEY, temperature=0.6\n",
" ),\n",
" retriever=retriever,\n",
" chain_type=\"stuff\",\n",
" chain_type_kwargs={\n",
" \"document_prompt\": document_with_metadata_prompt,\n",
" },\n",
" return_source_documents=True,\n",
")\n",
"ans = chain(\"Please give me 10 papers to ask what is PageRank?\",\n",
" callbacks=[StdOutCallbackHandler()])\n",
"print(ans[\"answer\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4948ff25",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -0,0 +1,38 @@
"""Vector SQL Database Chain Retriever"""
from typing import Any, Dict, List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document

from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain


class VectorSQLDatabaseChainRetriever(BaseRetriever):
    """Retriever that uses SQLDatabase as Retriever"""

    sql_db_chain: VectorSQLDatabaseChain
    """SQL Database Chain"""
    page_content_key: str = "content"
    """column name for page content of documents"""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        ret: List[Dict[str, Any]] = self.sql_db_chain(
            query, callbacks=run_manager.get_child(), **kwargs
        )["result"]
        return [
            Document(page_content=r[self.page_content_key], metadata=r) for r in ret
        ]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        raise NotImplementedError
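
For readers who want to see the row-to-document mapping in isolation, here is a minimal sketch that runs without LangChain. It assumes the chain returns a list of dict rows under `"result"`, as the type annotation in `_get_relevant_documents` suggests. `SimpleDocument` and `rows_to_documents` are hypothetical names introduced only for this illustration; they are not part of the PR.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class SimpleDocument:
    """Hypothetical stand-in for langchain.schema.Document (assumption)."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def rows_to_documents(
    rows: List[Dict[str, Any]], page_content_key: str = "content"
) -> List[SimpleDocument]:
    """Turn each SQL result row into a document.

    The column named by page_content_key becomes the text; the whole row
    (including that column) is kept as metadata, mirroring the retriever.
    """
    return [
        SimpleDocument(page_content=row[page_content_key], metadata=row)
        for row in rows
    ]


# Made-up row for illustration only; real rows come from the database.
rows = [{"id": "p1", "title": "PageRank", "abstract": "Ranks pages by links."}]
docs = rows_to_documents(rows, page_content_key="abstract")
```

Because the full row is reused as metadata, a row that lacks the configured `page_content_key` raises `KeyError`. That is presumably why the notebook pairs the retriever with a parser that retrieves all columns.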
85 changes: 85 additions & 0 deletions libs/experimental/langchain_experimental/sql/prompt.py
@@ -0,0 +1,85 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate


PROMPT_SUFFIX = """Only use the following tables:
{table_info}

Question: {input}"""

_VECTOR_SQL_DEFAULT_TEMPLATE = """You are a {dialect} expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
{dialect} queries have a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for the {top_k} closest rows, you have to use this distance function to calculate the distance to the entity's array on the vector column and order by the distance to retrieve relevant rows.

*NOTICE*: `DISTANCE(column, array)` only accepts an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array.

Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.

Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
"""

VECTOR_SQL_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "dialect", "top_k"],
template=_VECTOR_SQL_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
)


_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
MyScale queries have a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for the {top_k} closest rows, you have to use this distance function to calculate the distance to the entity's array on the vector column and order by the distance to retrieve relevant rows.

*NOTICE*: `DISTANCE(column, array)` only accepts an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array.

Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.

Use the following format:

======== table info ========
<some table infos>

Question: "Question here"
SQLQuery: "SQL Query to run"


Here are some examples:

======== table info ========
CREATE TABLE "ChatPaper" (
abstract String,
id String,
vector Array(Float32),
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id

Question: What is Feature Pyramid Network?
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(Feature Pyramid Network)) LIMIT {top_k}


Let's begin:
======== table info ========
{table_info}

Question: {input}
SQLQuery: """

MYSCALE_PROMPT = PromptTemplate(
input_variables=["input", "table_info", "top_k"],
template=_myscale_prompt + PROMPT_SUFFIX,
)


VECTOR_SQL_PROMPTS = {
"myscale": MYSCALE_PROMPT,
}
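
The prompts ask the model to emit `NeuralArray(entity)` placeholders, but the code that turns them into real vectors is not part of this file. As a rough sketch of the idea only, and not the PR's `VectorSQLOutputParser`, the snippet below replaces each placeholder with an array literal. `toy_embed` and `expand_neural_array` are hypothetical helpers, and the embedding is a toy function rather than a real model.

```python
import re
from typing import Callable, List


def toy_embed(text: str) -> List[float]:
    """Hypothetical embedding for illustration; NOT a real model."""
    return [float(len(text)), float(text.count("a"))]


def expand_neural_array(sql: str, embed: Callable[[str], List[float]]) -> str:
    """Replace every NeuralArray(entity) with a bracketed array literal.

    Assumes the entity text contains no closing parenthesis.
    """

    def _replace(match: "re.Match[str]") -> str:
        vector = embed(match.group(1).strip())
        return "[" + ",".join(str(x) for x in vector) + "]"

    return re.sub(r"NeuralArray\((.*?)\)", _replace, sql)


sql = (
    "SELECT id FROM ChatPaper "
    "ORDER BY DISTANCE(vector, NeuralArray(PageRank)) LIMIT 10"
)
expanded = expand_neural_array(sql, toy_embed)
```

The non-greedy pattern stops at the first `)`, so an entity that itself contains parentheses would need a more careful parser than this sketch provides.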