import json
from typing import List, Optional, Dict, Any
from pathlib import Path
import csv

from phi.tools import Toolkit
from phi.utils.log import logger

try:
    from sqlalchemy import create_engine, Engine
    from sqlalchemy.orm import Session, sessionmaker
    from sqlalchemy.inspection import inspect
    from sqlalchemy.sql.expression import text
except ImportError:
    raise ImportError("`sqlalchemy` not installed")


class SQLTools(Toolkit):
    """Toolkit exposing a SQL database through SQLAlchemy.

    Depending on the constructor flags it registers functions to list tables,
    describe a table, run a query and save query results to a CSV file.
    """

    def __init__(
        self,
        db_url: Optional[str] = None,
        db_engine: Optional[Engine] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        schema: Optional[str] = None,
        dialect: Optional[str] = None,
        tables: Optional[Dict[str, Any]] = None,
        base_dir: Optional[Path] = None,
        list_tables: bool = True,
        describe_table: bool = True,
        run_sql_query: bool = True,
        save_query_results: bool = True,
    ):
        """Build the database engine and register the enabled functions.

        The engine is resolved in this order of precedence: an explicit
        `db_engine`, then `db_url`, then a URL assembled from `dialect`,
        `user`, `password`, `host`, `port` and (optionally) `schema`.

        Args:
            db_url: Database URL passed to `create_engine`.
            db_engine: Ready-made SQLAlchemy engine to use as is.
            user: User name for the assembled URL.
            password: Password for the assembled URL.
            host: Host for the assembled URL.
            port: Port for the assembled URL.
            schema: Appended as the trailing path component of the assembled URL.
            dialect: URL scheme of the assembled URL.
            tables: If given, returned verbatim by `list_tables` instead of
                inspecting the database.
            base_dir: Directory used by `save_query_results`. Defaults to the
                current working directory.
            list_tables: Register `list_tables`.
            describe_table: Register `describe_table`.
            run_sql_query: Register `run_sql_query`.
            save_query_results: Register `save_query_results`.

        Raises:
            ValueError: If no engine could be built from the arguments.
        """
        super().__init__(name="sql_tools")

        # Get the database engine
        _engine: Optional[Engine] = db_engine
        if _engine is None:
            # Only build an engine when none was supplied, so an explicit
            # `db_engine` is never replaced by one built from credentials.
            if db_url is not None:
                _engine = create_engine(db_url)
            elif user and password and host and port and dialect:
                # NOTE(review): the parts are interpolated verbatim; characters
                # such as '@' or '/' in the password must already be URL-escaped.
                if schema is not None:
                    _engine = create_engine(f"{dialect}://{user}:{password}@{host}:{port}/{schema}")
                else:
                    _engine = create_engine(f"{dialect}://{user}:{password}@{host}:{port}")

        if _engine is None:
            raise ValueError("Could not build the database connection")

        # Database connection
        self.db_engine: Engine = _engine
        self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine)

        # Tables this toolkit can access
        self.tables: Optional[Dict[str, Any]] = tables

        # Base directory for saving query results
        self.base_dir: Path = base_dir or Path.cwd()

        # Register functions in the toolkit
        if list_tables:
            self.register(self.list_tables)
        if describe_table:
            self.register(self.describe_table)
        if run_sql_query:
            self.register(self.run_sql_query)
        if save_query_results:
            self.register(self.save_query_results)

    def list_tables(self) -> str:
        """Use this function to get a list of table names in the database.

        Returns:
            str: list of tables in the database.
        """
        # A configured `tables` mapping takes precedence over live inspection.
        if self.tables is not None:
            return json.dumps(self.tables)

        try:
            table_names = inspect(self.db_engine).get_table_names()
            logger.debug(f"table_names: {table_names}")
            return json.dumps(table_names)
        except Exception as e:
            logger.error(f"Error getting tables: {e}")
            return f"Error getting tables: {e}"

    def describe_table(self, table_name: str) -> str:
        """Use this function to describe a table.

        Args:
            table_name (str): The name of the table to get the schema for.

        Returns:
            str: schema of a table
        """
        try:
            inspector = inspect(self.db_engine)
            table_schema = inspector.get_columns(table_name)
            # Each column description is stringified so that it is always
            # JSON-serializable.
            return json.dumps([str(column) for column in table_schema])
        except Exception as e:
            logger.error(f"Error getting table schema: {e}")
            return f"Error getting table schema: {e}"

    def run_sql_query(self, query: str, limit: Optional[int] = 10) -> str:
        """Use this function to run a SQL query and return the result.

        Args:
            query (str): The query to run.
            limit (int, optional): The number of rows to return. Defaults to 10. Use `None` to show all results.

        Returns:
            str: Result of the SQL query.

        Notes:
            - The result may be empty if the query does not return any data.
        """
        try:
            # `default=str` stringifies values json cannot encode natively.
            return json.dumps(self.run_sql(sql=query, limit=limit), default=str)
        except Exception as e:
            logger.error(f"Error running query: {e}")
            return f"Error running query: {e}"

    def run_sql(self, sql: str, limit: Optional[int] = None) -> List[dict]:
        """Internal function to run a sql query.

        The statement runs inside a transaction opened with `sess.begin()`.

        Args:
            sql (str): The sql query to run.
            limit (int, optional): The number of rows to return. Defaults to None.
                A falsy value (`None` or `0`) fetches all rows.

        Returns:
            List[dict]: The result of the query. Empty when the statement
                returns no rows or when fetching the rows fails.
        """
        logger.debug(f"Running sql |\n{sql}")

        with self.Session() as sess, sess.begin():
            result = sess.execute(text(sql))

            # Statements such as INSERT/UPDATE/DDL produce no result set.
            # Return early instead of letting the fetch raise and be logged
            # as an error. `getattr` keeps the previous code path for result
            # objects that do not expose `returns_rows`.
            if not getattr(result, "returns_rows", True):
                return []

            try:
                rows = result.fetchmany(limit) if limit else result.fetchall()
                return [row._asdict() for row in rows]
            except Exception as e:
                logger.error(f"Error while executing SQL: {e}")
                return []

    def save_query_results(self, query: str, file_name: str, chunk_size: int = 1000) -> str:
        """Executes the given SQL query and saves the results to a file.

        The results are written as CSV (header row first) in UTF-8, fetched
        and written `chunk_size` rows at a time.

        Args:
            query (str): The SQL query to execute.
            file_name (str): The name of the file to save the results to.
            chunk_size (int, optional): The number of rows to fetch and write at a time. Defaults to 1000.

        Returns:
            str: The file name if successful, otherwise returns an error message.
        """
        try:
            logger.info(f"Executing query and saving results to file: {file_name}")
            # NOTE(review): `file_name` is joined onto `base_dir` unchecked; an
            # absolute path or `..` components resolve outside `base_dir`.
            # Confirm whether callers rely on that before restricting it.
            file_path = self.base_dir.joinpath(file_name)

            with self.Session() as sess, sess.begin():
                result = sess.execute(text(query))

                # Fail before touching the filesystem so a statement without a
                # result set does not leave an empty file behind. Raising here
                # still rolls the transaction back, as the failure did before.
                if not getattr(result, "returns_rows", True):
                    raise ValueError("Query did not return rows; nothing to save")

                # Explicit encoding keeps the output independent of the locale.
                with open(file_path, "w", newline="", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow(result.keys())  # Write header
                    while True:
                        chunk = result.fetchmany(chunk_size)
                        if not chunk:
                            break
                        writer.writerows(chunk)

            logger.info(f"Query results saved to {file_path}")
            return str(file_name)
        except Exception as e:
            logger.error(f"Error saving query results: {e}")
            return f"Error saving query results: {e}"