Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some fixes to text_to_sql_pipeline.py #346

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 31 additions & 34 deletions examples/pipelines/rag/text_to_sql_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
requirements: llama_index, sqlalchemy, psycopg2-binary
"""

import os
from typing import List, Union, Generator, Iterator
import os
from pydantic import BaseModel
from llama_index.llms.ollama import Ollama
from llama_index.core.query_engine import NLSQLTableQueryEngine

from llama_index.core import SQLDatabase, PromptTemplate
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.llms.ollama import Ollama
from pydantic import BaseModel
from sqlalchemy import create_engine


Expand All @@ -22,37 +23,37 @@ class Valves(BaseModel):
DB_HOST: str
DB_PORT: str
DB_USER: str
DB_PASSWORD: str
DB_PASSWORD: str
DB_DATABASE: str
DB_TABLE: str
OLLAMA_HOST: str
TEXT_TO_SQL_MODEL: str

TEXT_TO_SQL_MODEL: str

# Update valves/ environment variables based on your selected database
def __init__(self):
"""
Update valves/ environment variables based on your selected database
"""
self.name = "Database RAG Pipeline"
self.engine = None
self.nlsql_response = ""

# Initialize
self.valves = self.Valves(
**{
"pipelines": ["*"], # Connect to all pipelines
"DB_HOST": os.getenv("DB_HOST", "http://localhost"), # Database hostname
"DB_PORT": os.getenv("DB_PORT", 5432), # Database port
"DB_USER": os.getenv("DB_USER", "postgres"), # User to connect to the database with
"DB_PASSWORD": os.getenv("DB_PASSWORD", "password"), # Password to connect to the database with
"DB_DATABASE": os.getenv("DB_DATABASE", "postgres"), # Database to select on the DB instance
"DB_TABLE": os.getenv("DB_TABLE", "table_name"), # Table(s) to run queries against
"OLLAMA_HOST": os.getenv("OLLAMA_HOST", "http://host.docker.internal:11434"), # Make sure to update with the URL of your Ollama host, such as http://localhost:11434 or remote server address
"TEXT_TO_SQL_MODEL": os.getenv("TEXT_TO_SQL_MODEL", "llama3.1:latest") # Model to use for text-to-SQL generation
}
)
self.valves = self.Valves(**{"pipelines": ["*"], # Connect to all pipelines
"DB_HOST": os.getenv("DB_HOST", "host.docker.internal"), # Database hostname
"DB_PORT": os.getenv("DB_PORT", "5432"), # Database port
"DB_USER": os.getenv("DB_USER", "user_name"), # User to connect to the database with
"DB_PASSWORD": os.getenv("DB_PASSWORD", "password"), # Password to connect to the database with
"DB_DATABASE": os.getenv("DB_DATABASE", "postgres"), # Database to select on the DB instance
"DB_TABLE": os.getenv("DB_TABLE", "table_name"), # Table(s) to run queries against
"OLLAMA_HOST": os.getenv("OLLAMA_HOST", "http://host.docker.internal:11434"),
# Make sure to update with the URL of your Ollama host, such as http://localhost:11434 or remote server address
"TEXT_TO_SQL_MODEL": os.getenv("TEXT_TO_SQL_MODEL", "llama3.1:latest") # Model to use for text-to-SQL generation
})

def init_db_connection(self):
# Update your DB connection string based on selected DB engine - current connection string is for Postgres
self.engine = create_engine(f"postgresql+psycopg2://{self.valves.DB_USER}:{self.valves.DB_PASSWORD}@{self.valves.DB_HOST}:{self.valves.DB_PORT}/{self.valves.DB_DATABASE}")
self.engine = create_engine(
f"postgresql+psycopg2://{self.valves.DB_USER}:{self.valves.DB_PASSWORD}@{self.valves.DB_HOST}:{self.valves.DB_PORT}/{self.valves.DB_DATABASE}")
return self.engine

async def on_startup(self):
Expand All @@ -63,16 +64,17 @@ async def on_shutdown(self):
# This function is called when the server is stopped.
pass

def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> Union[
str, Generator, Iterator]:
# Debug logging is required to see what SQL query is generated by the LlamaIndex library; enable on Pipelines server if needed
print("pipe method triggered with:", user_message)

# Create database reader for Postgres
sql_database = SQLDatabase(self.engine, include_tables=[self.valves.DB_TABLE])

# Set up LLM connection; uses phi3 model with 128k context limit since some queries have returned 20k+ tokens
llm = Ollama(model=self.valves.TEXT_TO_SQL_MODEL, base_url=self.valves.OLLAMA_HOST, request_timeout=180.0, context_window=30000)
llm = Ollama(model=self.valves.TEXT_TO_SQL_MODEL, base_url=self.valves.OLLAMA_HOST, request_timeout=180.0,
context_window=30000)

# Set up the custom prompt used when generating SQL queries from text
text_to_sql_prompt = """
Expand All @@ -97,14 +99,9 @@ def pipe(

text_to_sql_template = PromptTemplate(text_to_sql_prompt)

query_engine = NLSQLTableQueryEngine(
sql_database=sql_database,
tables=[self.valves.DB_TABLE],
llm=llm,
embed_model="local",
text_to_sql_prompt=text_to_sql_template,
streaming=True
)
query_engine = NLSQLTableQueryEngine(sql_database=sql_database, tables=[self.valves.DB_TABLE], llm=llm,
embed_model="local", text_to_sql_prompt=text_to_sql_template,
streaming=True)

response = query_engine.query(user_message)

Expand Down