Skip to content

Commit

Permalink
Pull the database name for the Postgres connection only from the conn…
Browse files Browse the repository at this point in the history
…ection definition (#117)
  • Loading branch information
TJaniF authored Oct 10, 2023
1 parent 09dbdac commit 71280b7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 51 deletions.
14 changes: 12 additions & 2 deletions great_expectations_provider/operators/great_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ def make_connection_configuration(self) -> Dict[str, str]:
raise ValueError(f"Connections does not exist in Airflow for conn_id: {self.conn_id}")
self.schema = self.schema or self.conn.schema
conn_type = self.conn.conn_type
if conn_type in ("redshift", "postgres", "mysql", "mssql"):
if conn_type in ("redshift", "mysql", "mssql"):
odbc_connector = ""
if conn_type in ("redshift", "postgres"):
if conn_type in ("redshift"):
odbc_connector = "postgresql+psycopg2"
database_name = self.schema
elif conn_type == "mysql":
Expand All @@ -263,6 +263,16 @@ def make_connection_configuration(self) -> Dict[str, str]:
f"{odbc_connector}://{self.conn.login}:{self.conn.password}@"
f"{self.conn.host}:{self.conn.port}/{database_name}{driver}"
)
elif conn_type == "postgres":
# the schema parameter in the postgres connection is the database name
if self.conn.schema:
postgres_database = self.conn.schema
odbc_connector = "postgresql+psycopg2"
uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{postgres_database}" # noqa
else:
raise ValueError(
"Specify the name of the database in the schema parameter of the Postgres connection. See: https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/connections/postgres.html" # noqa
)
elif conn_type == "snowflake":
try:
return self.build_snowflake_connection_config_from_hook()
Expand Down
110 changes: 61 additions & 49 deletions tests/operators/test_great_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import logging
import os
import tempfile
import unittest.mock as mock
from pathlib import Path

Expand Down Expand Up @@ -859,57 +860,68 @@ def test_great_expectations_operator__make_connection_string_snowflake(mocker):


def test_great_expectations_operator__make_connection_string_snowflake_pkey(mocker):
private_key_bytes = b"secret"
test_conn_conf = {
"url": URL.create(
drivername="snowflake",
username="user",
password="",
host="account.region-east-1",
database="database/schema",
query={"role": "role", "warehouse": "warehouse", "authenticator": "snowflake", "application": "AIRFLOW"},
).render_as_string(hide_password=False),
"connect_args": {"private_key": private_key_bytes},
}
operator = GreatExpectationsOperator(
task_id="task_id",
data_context_config=in_memory_data_context_config,
data_asset_name="test_runtime_data_asset",
conn_id="snowflake_default",
query_to_validate="SELECT * FROM db;",
expectation_suite_name="suite",
)
operator.conn = Connection(
conn_id="snowflake_default",
conn_type="snowflake",
host="connection",
login="user",
password="password",
schema="schema",
port=5439,
extra={
"extra__snowflake__role": "role",
"extra__snowflake__warehouse": "warehouse",
"extra__snowflake__database": "database",
"extra__snowflake__region": "region-east-1",
"extra__snowflake__account": "account",
"extra__snowflake__private_key_file": "/path/to/key.p8",
},
)
operator.conn_type = operator.conn.conn_type
# create a temp key file
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
private_key_bytes = b"fake_key"
temp_file.write(private_key_bytes)
temp_file.flush()
test_conn_conf = {
"url": URL.create(
drivername="snowflake",
username="user",
password="",
host="account.region-east-1",
database="database/schema",
query={
"role": "role",
"warehouse": "warehouse",
"authenticator": "snowflake",
"application": "AIRFLOW",
},
).render_as_string(hide_password=False),
"connect_args": {"private_key": private_key_bytes},
}
operator = GreatExpectationsOperator(
task_id="task_id",
data_context_config=in_memory_data_context_config,
data_asset_name="test_runtime_data_asset",
conn_id="snowflake_default",
query_to_validate="SELECT * FROM db;",
expectation_suite_name="suite",
)
operator.conn = Connection(
conn_id="snowflake_default",
conn_type="snowflake",
host="connection",
login="user",
password="password",
schema="schema",
port=5439,
extra={
"extra__snowflake__role": "role",
"extra__snowflake__warehouse": "warehouse",
"extra__snowflake__database": "database",
"extra__snowflake__region": "region-east-1",
"extra__snowflake__account": "account",
"extra__snowflake__private_key_file": temp_file.name,
},
)
operator.conn_type = operator.conn.conn_type

mocker.patch(
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_connection", return_value=operator.conn
)
mocker.patch("great_expectations_provider.operators.great_expectations.Path.read_bytes", return_value=b"dummy")
mocked_key = mock.MagicMock(default_backend())
mocked_key.private_bytes = mock.MagicMock(return_value=private_key_bytes)
mocker.patch(
"cryptography.hazmat.primitives.serialization.load_pem_private_key",
return_value=mocked_key,
)
mocker.patch(
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_connection", return_value=operator.conn
)
mocker.patch(
"great_expectations_provider.operators.great_expectations.Path.read_bytes", return_value=b"fake_key"
)
mocked_key = mock.MagicMock(default_backend())
mocked_key.private_bytes = mock.MagicMock(return_value=private_key_bytes)
mocker.patch(
"cryptography.hazmat.primitives.serialization.load_pem_private_key",
return_value=mocked_key,
)

assert operator.make_connection_configuration() == test_conn_conf
assert operator.make_connection_configuration() == test_conn_conf


def test_great_expectations_operator__make_connection_string_sqlite():
Expand Down

0 comments on commit 71280b7

Please sign in to comment.