Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sqlalchemy's 'quote' function to quote table names and fix table quoting in cfr exporter #172

Merged
merged 2 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cratedb_toolkit/cfr/systable.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def save(self) -> Path:

path_table_schema = path_schema / f"{ExportSettings.TABLE_FILENAME_PREFIX}{tablename}.sql"
path_table_data = path_data / f"{ExportSettings.TABLE_FILENAME_PREFIX}{tablename}.{self.data_format}"
tablename_out = f"{ExportSettings.TABLE_FILENAME_PREFIX}{tablename}"
tablename_out = self.adapter.quote_relation_name(f"{ExportSettings.TABLE_FILENAME_PREFIX}{tablename}")

# Write schema file.
with open(path_table_schema, "w") as fh_schema:
Expand Down
5 changes: 3 additions & 2 deletions cratedb_toolkit/testing/testcontainers/cratedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from cratedb_toolkit.testing.testcontainers.util import KeepaliveContainer, asbool
from cratedb_toolkit.util import DatabaseAdapter
from cratedb_toolkit.util.database import quote_table_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -189,7 +188,9 @@ def reset(self, tables: Optional[list] = None):
"""
if tables and self.database:
for reset_table in tables:
self.database.connection.exec_driver_sql(f"DROP TABLE IF EXISTS {quote_table_name(reset_table)};")
self.database.connection.exec_driver_sql(
f"DROP TABLE IF EXISTS {self.database.quote_relation_name(reset_table)};"
)

def get_connection_url(self, *args, **kwargs):
"""
Expand Down
60 changes: 37 additions & 23 deletions cratedb_toolkit/util/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,38 @@
self.engine = sa.create_engine(self.dburi, echo=echo)
self.connection = self.engine.connect()

def quote_relation_name(self, ident: str) -> str:
    """
    Quote the given, possibly fully-qualified, relation name if needed.

    The actual quoting decision is delegated to the identifier preparer
    of the engine's SQLAlchemy dialect. A name that is already enclosed
    in double quotes as a whole is returned verbatim.

    In: foo
    Out: foo

    In: Foo
    Out: "Foo"

    In: "Foo"
    Out: "Foo"

    In: foo.bar
    Out: foo.bar

    In: "foo.bar"
    Out: "foo.bar"

    Raises `ValueError` when the name is empty, or when it consists of
    more than two dot-separated parts.
    """
    # An empty name can never be a valid relation; fail with a clear error
    # instead of an `IndexError` from peeking at the first character.
    if not ident:
        raise ValueError("Invalid relation name: empty string")

    # Already quoted as a whole: pass through untouched. The length check
    # keeps a lone `"` from being mistaken for a quoted identifier.
    if len(ident) >= 2 and ident.startswith('"') and ident.endswith('"'):
        return ident

    # Validate the shape before touching the engine.
    parts = ident.split(".")
    if len(parts) > 2:
        raise ValueError(f"Invalid relation name {ident}")

    preparer = self.engine.dialect.identifier_preparer
    if len(parts) == 2:
        schema, relation = parts
        return preparer.quote_schema(schema) + "." + preparer.quote(relation)
    return preparer.quote(ident=ident)

def run_sql(self, sql: t.Union[str, Path, io.IOBase], records: bool = False, ignore: str = None):
"""
Run SQL statement, and return results, optionally ignoring exceptions.
Expand Down Expand Up @@ -82,7 +114,7 @@
"""
Return number of records in table.
"""
sql = f"SELECT COUNT(*) AS count FROM {quote_table_name(name)};" # noqa: S608
sql = f"SELECT COUNT(*) AS count FROM {self.quote_relation_name(name)};" # noqa: S608
try:
results = self.run_sql(sql=sql)
except ProgrammingError as ex:
Expand All @@ -96,7 +128,7 @@
"""
Check whether given table exists.
"""
sql = f"SELECT 1 FROM {quote_table_name(name)} LIMIT 1;" # noqa: S608
sql = f"SELECT 1 FROM {self.quote_relation_name(name)} LIMIT 1;" # noqa: S608
try:
self.run_sql(sql=sql)
return True
Expand All @@ -107,15 +139,15 @@
"""
Run a `REFRESH TABLE ...` command.
"""
sql = f"REFRESH TABLE {quote_table_name(name)};" # noqa: S608
sql = f"REFRESH TABLE {self.quote_relation_name(name)};" # noqa: S608

Check warning on line 142 in cratedb_toolkit/util/database.py

View check run for this annotation

Codecov / codecov/patch

cratedb_toolkit/util/database.py#L142

Added line #L142 was not covered by tests
self.run_sql(sql=sql)
return True

def prune_table(self, name: str, errors: Literal["raise", "ignore"] = "raise"):
"""
Run a `DELETE FROM ...` command.
"""
sql = f"DELETE FROM {quote_table_name(name)};" # noqa: S608
sql = f"DELETE FROM {self.quote_relation_name(name)};" # noqa: S608
try:
self.run_sql(sql=sql)
except ProgrammingError as ex:
Expand All @@ -129,7 +161,7 @@
"""
Run a `DROP TABLE ...` command.
"""
sql = f"DROP TABLE IF EXISTS {quote_table_name(name)};" # noqa: S608
sql = f"DROP TABLE IF EXISTS {self.quote_relation_name(name)};" # noqa: S608
self.run_sql(sql=sql)
return True

Expand Down Expand Up @@ -332,21 +364,3 @@
if url_.scheme == "crate" and not database:
database = url_.query_params.get("schema")
return database, table


def quote_table_name(name: str) -> str:
    """
    Wrap a bare table name in double quotes.

    A name that already contains a double quote, or that contains a dot
    (like `foo.bar`), is handed back unchanged.

    In: foo
    Out: "foo"

    In: "foo"
    Out: "foo"

    In: foo.bar
    Out: foo.bar
    """
    leave_untouched = '"' in name or "." in name
    if leave_untouched:
        return name
    return f'"{name}"'
16 changes: 16 additions & 0 deletions tests/cfr/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ def test_cfr_cli_export_failure(cratedb, tmp_path, caplog):
assert result.output == ""


def test_cfr_cli_export_ensure_table_name_is_quoted(cratedb, tmp_path, caplog):
    """
    Verify the `sys-export` command writes a quoted table name into the schema file.
    """
    environment = {"CRATEDB_SQLALCHEMY_URL": cratedb.database.dburi, "CFR_TARGET": str(tmp_path)}
    outcome = CliRunner(env=environment).invoke(
        cli,
        args="--debug sys-export",
        catch_exceptions=False,
    )
    assert outcome.exit_code == 0

    # The command reports the export directory as JSON on its output.
    export_root = Path(json.loads(outcome.output)["path"])
    schema_sql = (export_root / "schema" / "sys-cluster.sql").read_text()
    assert '"sys-cluster"' in schema_sql, "Table name missing or not quoted"


def test_cfr_cli_import_success(cratedb, tmp_path, caplog):
"""
Verify `ctk cfr sys-import` works.
Expand Down
23 changes: 23 additions & 0 deletions tests/util/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from cratedb_toolkit.util import DatabaseAdapter


def test_quote_relation_name():
    """
    Verify `quote_relation_name` against a table of (input, expected output) pairs.
    """
    database = DatabaseAdapter(dburi="crate://localhost")
    expectations = [
        ("my_table", "my_table"),
        ("my-table", '"my-table"'),
        ("MyTable", '"MyTable"'),
        ('"MyTable"', '"MyTable"'),
        ("my_schema.my_table", "my_schema.my_table"),
        ("my-schema.my_table", '"my-schema".my_table'),
        ('"wrong-quoted-fqn.my_table"', '"wrong-quoted-fqn.my_table"'),
        ('"my_schema"."my_table"', '"my_schema"."my_table"'),
        # reserved keyword must be quoted
        ("table", '"table"'),
    ]
    for given, expected in expectations:
        assert database.quote_relation_name(given) == expected


def test_quote_relation_name_with_invalid_fqn():
    """
    Verify a name with more than two dot-separated parts is rejected.
    """
    adapter = DatabaseAdapter(dburi="crate://localhost")
    too_many_parts = "my-db.my-schema.my-table"
    with pytest.raises(ValueError):
        adapter.quote_relation_name(too_many_parts)