Skip to content

Commit

Permalink
Use sqlalchemy's 'quote' function to quote table names
Browse files Browse the repository at this point in the history
  • Loading branch information
seut committed Jun 13, 2024
1 parent 0741181 commit a00705b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
5 changes: 3 additions & 2 deletions cratedb_toolkit/testing/testcontainers/cratedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from cratedb_toolkit.testing.testcontainers.util import KeepaliveContainer, asbool
from cratedb_toolkit.util import DatabaseAdapter
from cratedb_toolkit.util.database import quote_table_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -189,7 +188,9 @@ def reset(self, tables: Optional[list] = None):
"""
if tables and self.database:
for reset_table in tables:
self.database.connection.exec_driver_sql(f"DROP TABLE IF EXISTS {quote_table_name(reset_table)};")
self.database.connection.exec_driver_sql(
f"DROP TABLE IF EXISTS {self.database.quote_ident(reset_table)};"
)

def get_connection_url(self, *args, **kwargs):
"""
Expand Down
48 changes: 25 additions & 23 deletions cratedb_toolkit/util/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,26 @@ def __init__(self, dburi: str, echo: bool = False):
self.engine = sa.create_engine(self.dburi, echo=echo)
self.connection = self.engine.connect()

def quote_ident(self, ident: str) -> str:
"""
Quote the given identifier if needed.
In: foo
Out: foo
In: Foo
Out: "Foo"
In: "Foo"
Out: "Foo"
In: foo.bar
Out: "foo.bar"
"""
if ident[0] == '"' and ident[len(ident) - 1] == '"':
return ident
return self.engine.dialect.identifier_preparer.quote(ident=ident)

def run_sql(self, sql: t.Union[str, Path, io.IOBase], records: bool = False, ignore: str = None):
"""
Run SQL statement, and return results, optionally ignoring exceptions.
Expand Down Expand Up @@ -82,7 +102,7 @@ def count_records(self, name: str, errors: Literal["raise", "ignore"] = "raise")
"""
Return number of records in table.
"""
sql = f"SELECT COUNT(*) AS count FROM {quote_table_name(name)};" # noqa: S608
sql = f"SELECT COUNT(*) AS count FROM {self.quote_ident(name)};" # noqa: S608
try:
results = self.run_sql(sql=sql)
except ProgrammingError as ex:
Expand All @@ -96,7 +116,7 @@ def table_exists(self, name: str) -> bool:
"""
Check whether given table exists.
"""
sql = f"SELECT 1 FROM {quote_table_name(name)} LIMIT 1;" # noqa: S608
sql = f"SELECT 1 FROM {self.quote_ident(name)} LIMIT 1;" # noqa: S608
try:
self.run_sql(sql=sql)
return True
Expand All @@ -107,15 +127,15 @@ def refresh_table(self, name: str):
"""
Run a `REFRESH TABLE ...` command.
"""
sql = f"REFRESH TABLE {quote_table_name(name)};" # noqa: S608
sql = f"REFRESH TABLE {self.quote_ident(name)};" # noqa: S608
self.run_sql(sql=sql)
return True

def prune_table(self, name: str, errors: Literal["raise", "ignore"] = "raise"):
"""
Run a `DELETE FROM ...` command.
"""
sql = f"DELETE FROM {quote_table_name(name)};" # noqa: S608
sql = f"DELETE FROM {self.quote_ident(name)};" # noqa: S608
try:
self.run_sql(sql=sql)
except ProgrammingError as ex:
Expand All @@ -129,7 +149,7 @@ def drop_table(self, name: str):
"""
Run a `DROP TABLE ...` command.
"""
sql = f"DROP TABLE IF EXISTS {quote_table_name(name)};" # noqa: S608
sql = f"DROP TABLE IF EXISTS {self.quote_ident(name)};" # noqa: S608
self.run_sql(sql=sql)
return True

Expand Down Expand Up @@ -332,21 +352,3 @@ def decode_database_table(url: str) -> t.Tuple[str, str]:
if url_.scheme == "crate" and not database:
database = url_.query_params.get("schema")
return database, table


def quote_table_name(name: str) -> str:
"""
Quote table name if not happened already.
In: foo
Out: "foo"
In: "foo"
Out: "foo"
In: foo.bar
Out: foo.bar
"""
if '"' not in name and "." not in name:
name = f'"{name}"'
return name

0 comments on commit a00705b

Please sign in to comment.