feat: Text to SQL support #1399

Open · wants to merge 20 commits into base: main
2 changes: 2 additions & 0 deletions fern/docs.yml
@@ -62,6 +62,8 @@ navigation:
        contents:
          - page: LLM Backends
            path: ./docs/pages/manual/llms.mdx
          - page: Context Configurations
            path: ./docs/pages/manual/contexts.mdx
      - section: User Interface
        contents:
          - page: User interface (Gradio) Manual
73 changes: 73 additions & 0 deletions fern/docs/pages/manual/contexts.mdx
@@ -0,0 +1,73 @@
## SQL Databases
Text-to-SQL querying has been tested with [MySQL](https://www.mysql.com/) in PrivateGPT.

The database connection is established through [SQLAlchemy](https://www.sqlalchemy.org/), which supports connecting to a wide range of databases.
To use this feature, the connection details for the database are required.
The database name and the list of tables tell PrivateGPT which schema to use as the query context.

### Install Dependencies
* Run the following command to install the dependencies for text-to-SQL:

```bash
poetry install --with context_database
```

* Install an appropriate database driver:

```bash
poetry add <db_driver>
```

* Set the following properties in `settings.yaml` (the shipped configuration reads them from environment variables); a filled-in example is shown below:

```yaml
context_database:
  enabled: true
  db_dialect: <db_dialect_here>
  db_driver: <db_driver_here>
  db_host: <db_hostname_here>
  db_user: <db_username_here>
  db_password: <db_password_here>
  database: <database_name_here>
  tables: <comma_separated_table_names_here> # eg: TABLE1,TABLE2,TABLE3
```
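
For illustration, a hypothetical filled-in configuration for a MySQL database could look like the following (every value is a placeholder chosen for this example, not a default shipped with PrivateGPT):

```yaml
context_database:
  enabled: true
  db_dialect: mysql
  db_driver: pymysql
  db_host: localhost
  db_user: pgpt_user
  db_password: changeme
  database: sales
  tables: CUSTOMERS,ORDERS,PRODUCTS # comma separated, as documented above
```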

### Configuring the Database
* The dialect is the system SQLAlchemy uses to communicate with a particular type of database.
* The driver is the DBAPI connector used to connect to that database.

To start using other SQL databases as context, set the `context_database.db_dialect` and `context_database.db_driver` properties in the `settings.yaml` file.

These values are combined into a connection string of the following form:
```bash
db_dialect+db_driver://db_user:db_password@db_host/database
```
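
Before enabling the feature, the resulting connection string can be checked directly with SQLAlchemy. The sketch below is only an illustration: it assumes a reachable MySQL server and uses the PyMySQL driver and made-up credentials.

```python
from urllib.parse import quote_plus

from sqlalchemy import create_engine, text

# Hypothetical connection details; substitute the values from your settings.yaml.
db_dialect, db_driver = "mysql", "pymysql"
db_user, db_password = "pgpt_user", "changeme"
db_host, database = "localhost", "sales"

# Assemble the db_dialect+db_driver://db_user:db_password@db_host/database string,
# URL-encoding the password in case it contains special characters.
url = f"{db_dialect}+{db_driver}://{db_user}:{quote_plus(db_password)}@{db_host}/{database}"

engine = create_engine(url)
with engine.connect() as connection:
    # A trivial round trip to confirm the connection works.
    print(connection.execute(text("SELECT 1")).scalar())
```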

##### List of included Dialects in SQLAlchemy

| Database | Dialect | Driver Options |
|-----------------------|------------|---------------------------------------------------------- |
| PostgreSQL            | postgresql | psycopg2/psycopg/pg8000/asyncpg                            |
| MySQL | mysql | pymysql/mysqldb/mysqlconnector/asyncmy/aiomysql/cymysql |
| MariaDB | mariadb | pymysql/mysqldb/mariadbconnector/asyncmy/aiomysql/cymysql |
| SQLite | sqlite | pysqlite/aiosqlite/pysqlcipher |
| Oracle                | oracle     | cx_oracle/oracledb                                         |
| Microsoft SQL Server | mssql | pyodbc/pymssql/aioodbc |


##### Example connection strings

| Database | Example |
|-----------------------|----------|
| MySQL | `mysql+pymysql://<username>:<password>@<host>/<dbname>` |
| MariaDB | `mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname>` |
| PostgreSQL | `postgresql+psycopg2://user:password@host:port/dbname` |
| SQLite | `sqlite+pysqlite:///file_path` |


##### Additional documentation
Refer to [SQLAlchemy Engine Configuration](https://docs.sqlalchemy.org/en/20/core/engines.html) for documentation about configuring SQLAlchemy and the DB connection string.

List of internal & external Dialects: https://docs.sqlalchemy.org/en/20/dialects/

PostgreSQL Documentation Reference: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html
2 changes: 1 addition & 1 deletion poetry.lock


1 change: 0 additions & 1 deletion private_gpt/components/embedding/embedding_component.py
@@ -27,7 +27,6 @@ def __init__(self, settings: Settings) -> None:
                    cache_folder=str(models_cache_path),
                )
            case "sagemaker":

                from private_gpt.components.embedding.custom.sagemaker import (
                    SagemakerEmbedding,
                )
Empty file.
71 changes: 71 additions & 0 deletions private_gpt/components/nlsql/nlsql_component.py
@@ -0,0 +1,71 @@
import logging
from typing import Any
from urllib.parse import quote_plus

from injector import inject, singleton
from llama_index import ServiceContext, SQLDatabase, VectorStoreIndex
from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema

from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


@singleton
class NLSQLComponent:
    sqlalchemy_engine: Any
    sql_database: Any
    metadata_obj: Any

    @inject
    def __init__(self, settings: Settings) -> None:
        if settings.context_database.enabled:
            dialect = settings.context_database.db_dialect
            driver = settings.context_database.db_driver
            host = settings.context_database.db_host
            user = settings.context_database.db_user
            password = settings.context_database.db_password
            database = settings.context_database.database
            tables = settings.context_database.tables
            try:
                from sqlalchemy import (
                    MetaData,
                )
                from sqlalchemy.engine import create_engine

                engine = create_engine(
                    f"{dialect}+{driver}://{user}:%s@{host}/{database}"
                    % quote_plus(password)
                )
            except BaseException as error:
                raise ValueError(
                    f"Unable to initialise connection to SQL Database\n{error}"
                ) from error

            metadata_obj = MetaData()
            metadata_obj.reflect(engine)
            sql_database = SQLDatabase(engine, include_tables=tables)
            self.sqlalchemy_engine = engine
            self.sql_database = sql_database
            self.metadata_obj = metadata_obj

    def get_nlsql_query_engine(
        self,
        service_context: ServiceContext,
    ) -> SQLTableRetrieverQueryEngine:
        table_node_mapping = SQLTableNodeMapping(self.sql_database)
        table_schema_objs = []
        for table_name in self.metadata_obj.tables:
            table_schema_objs.append(SQLTableSchema(table_name=table_name))
        obj_index = ObjectIndex.from_objects(
            table_schema_objs,
            table_node_mapping,
            VectorStoreIndex,
            service_context=service_context,
        )
        return SQLTableRetrieverQueryEngine(
            service_context=service_context,
            sql_database=self.sql_database,
            table_retriever=obj_index.as_retriever(similarity_top_k=1),
        )
1 change: 0 additions & 1 deletion private_gpt/launcher.py
@@ -17,7 +17,6 @@


def create_app(root_injector: Injector) -> FastAPI:

    # Start the API
    async def bind_injector_to_request(request: Request) -> None:
        request.state.injector = root_injector
26 changes: 26 additions & 0 deletions private_gpt/server/chat/chat_service.py
@@ -7,12 +7,14 @@
    BaseChatEngine,
)
from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.llms import ChatMessage, MessageRole
from llama_index.types import TokenGen
from pydantic import BaseModel

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.nlsql.nlsql_component import NLSQLComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
@@ -31,6 +33,11 @@ class CompletionGen(BaseModel):
    sources: list[Chunk] | None = None


class SqlQueryResponse(BaseModel):
    response: str
    sources: None = None


@dataclass
class ChatEngineInput:
    system_message: ChatMessage | None = None
@@ -74,9 +81,11 @@ def __init__(
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
        nlsql_component: NLSQLComponent,
    ) -> None:
        self.llm_service = llm_component
        self.vector_store_component = vector_store_component
        self.nlsql_component = nlsql_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
@@ -116,6 +125,13 @@ def _chat_engine(
                service_context=self.service_context,
            )

    def _nlsql_engine(
        self,
    ) -> SQLTableRetrieverQueryEngine:
        return self.nlsql_component.get_nlsql_query_engine(
            service_context=self.service_context
        )

    def stream_chat(
        self,
        messages: list[ChatMessage],
@@ -185,3 +201,13 @@ def chat(
        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
        completion = Completion(response=wrapped_response.response, sources=sources)
        return completion

    def stream_chat_nlsql(
        self,
        messages: list[ChatMessage],
    ) -> SqlQueryResponse:
        last_message = str(messages[-1].content)
        nlsql_engine = self._nlsql_engine()
        response = nlsql_engine.query(last_message)
        query = SqlQueryResponse(response=str(response))
        return query
34 changes: 34 additions & 0 deletions private_gpt/settings/settings.py
@@ -217,6 +217,39 @@ class QdrantSettings(BaseModel):
    )


class SQLDatabaseSettings(BaseModel):
    enabled: bool = Field(
        False,
        description="Flag to enable SQL Query mode. Disabled by default",
    )
    db_dialect: str = Field(
        None,
        description="Supported dialect in SQLAlchemy to be used for connecting with the DBAPI",
    )
    db_driver: str = Field(
        None, description="Drivername of the DBAPI for connecting with Database"
    )
    db_host: str = Field(
        "localhost",
        description="Host name of Database server. Defaults to 'localhost'.",
    )
    db_user: str = Field(
        None,
        description="Username to be used for accessing the SQL Database Server. Defaults to None.",
    )
    db_password: str = Field(
        None,
        description="Password to be used for accessing the SQL Database Server. Defaults to None.",
    )
    database: str = Field(
        None,
        description="The database name in which tables are to be queried. Defaults to None.",
    )
    tables: list[str] | None = Field(
        None, description="List of tables to use as context. Defaults to [None]"
    )


class Settings(BaseModel):
    server: ServerSettings
    data: DataSettings
@@ -228,6 +261,7 @@ class Settings(BaseModel):
    openai: OpenAISettings
    vectorstore: VectorstoreSettings
    qdrant: QdrantSettings | None = None
    context_database: SQLDatabaseSettings


"""
5 changes: 5 additions & 0 deletions private_gpt/settings/yaml.py
@@ -26,6 +26,11 @@ def load_env_var(_, node) -> str:
        env_var = split[0]
        value = environ.get(env_var)
        default = None if len(split) == 1 else split[1]
        if env_var in ["TABLES_LIST"]:
            if value is not None:
                value = value.split(",")
            else:
                default = []
        if value is None and default is None:
            raise ValueError(
                f"Environment variable {env_var} is not set and not default was provided"
21 changes: 18 additions & 3 deletions private_gpt/ui/ui.py
@@ -14,7 +14,11 @@

from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.di import global_injector
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chat.chat_service import (
    ChatService,
    CompletionGen,
    SqlQueryResponse,
)
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.settings.settings import settings
@@ -31,6 +35,8 @@
SOURCES_SEPARATOR = "\n\n Sources: \n"

MODES = ["Query Docs", "Search in Docs", "LLM Chat"]
if settings().context_database.enabled:
    MODES.append("Query Db")


class Source(BaseModel):
@@ -78,7 +84,9 @@ def __init__(
        self._system_prompt = self._get_default_system_prompt(self.mode)

    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
        def yield_deltas(
            completion_gen: CompletionGen | SqlQueryResponse, sources: bool = True
        ) -> Iterable[str]:
            full_response: str = ""
            stream = completion_gen.response
            for delta in stream:
@@ -88,7 +96,7 @@ def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
                    full_response += delta.delta or ""
                yield full_response

            if completion_gen.sources:
            if sources and completion_gen.sources:
                full_response += SOURCES_SEPARATOR
                cur_sources = Source.curate_sources(completion_gen.sources)
                sources_text = "\n\n\n".join(
@@ -136,6 +144,13 @@ def build_history() -> list[ChatMessage]:
                    use_context=True,
                )
                yield from yield_deltas(query_stream)

            case "Query Db":
                sql_stream = self._chat_service.stream_chat_nlsql(
                    messages=all_messages,
                )
                yield from yield_deltas(sql_stream, False)

            case "LLM Chat":
                llm_stream = self._chat_service.stream_chat(
                    messages=all_messages,
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -43,6 +43,12 @@ sentence-transformers = "^2.2.2"
torch = ">=2.0.0, !=2.0.1, !=2.1.0"
transformers = "^4.34.0"

# Dependencies for using Database as context
[tool.poetry.group.context_database]
optional = true
[tool.poetry.group.context_database.dependencies]
cryptography = "^41.0.7"

[tool.poetry.extras]
chroma = ["chromadb"]

1 change: 0 additions & 1 deletion scripts/ingest_folder.py
@@ -84,7 +84,6 @@ def _do_ingest_one(self, changed_path: Path) -> None:
    logger.addHandler(file_handler)

if __name__ == "__main__":

    root_path = Path(args.folder)
    if not root_path.exists():
        raise ValueError(f"Path {args.folder} does not exist")
11 changes: 11 additions & 0 deletions settings.yaml
@@ -44,6 +44,17 @@ vectorstore:
qdrant:
  path: local_data/private_gpt/qdrant

context_database:
  # Refer https://docs.privategpt.dev/manual/advanced-setup/context-configuration
  enabled: false
  db_dialect: ${DIALECT:}
  db_driver: ${DRIVER:}
  db_host: ${HOSTNAME:localhost}
  db_user: ${USERNAME:}
  db_password: ${PASSWORD:}
  database: ${DATABASE:}
  tables: ${TABLES_LIST:} # Should be comma separated, for example: TABLE1,TABLE2,TABLE3

local:
  prompt_style: "llama2"
  llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.1-GGUF