diff --git a/great_expectations_provider/operators/great_expectations.py b/great_expectations_provider/operators/great_expectations.py index c901e36..6ee0e23 100644 --- a/great_expectations_provider/operators/great_expectations.py +++ b/great_expectations_provider/operators/great_expectations.py @@ -244,9 +244,9 @@ def make_connection_configuration(self) -> Dict[str, str]: raise ValueError(f"Connections does not exist in Airflow for conn_id: {self.conn_id}") self.schema = self.schema or self.conn.schema conn_type = self.conn.conn_type - if conn_type in ("redshift", "postgres", "mysql", "mssql"): + if conn_type in ("redshift", "mysql", "mssql"): odbc_connector = "" - if conn_type in ("redshift", "postgres"): + if conn_type in ("redshift"): odbc_connector = "postgresql+psycopg2" database_name = self.schema elif conn_type == "mysql": @@ -263,6 +263,16 @@ def make_connection_configuration(self) -> Dict[str, str]: f"{odbc_connector}://{self.conn.login}:{self.conn.password}@" f"{self.conn.host}:{self.conn.port}/{database_name}{driver}" ) + elif conn_type == "postgres": + # the schema parameter in the postgres connection is the database name + if self.conn.schema: + postgres_database = self.conn.schema + odbc_connector = "postgresql+psycopg2" + uri_string = f"{odbc_connector}://{self.conn.login}:{self.conn.password}@{self.conn.host}:{self.conn.port}/{postgres_database}" # noqa + else: + raise ValueError( + "Specify the name of the database in the schema parameter of the Postgres connection. See: https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/connections/postgres.html" # noqa + ) elif conn_type == "snowflake": try: return self.build_snowflake_connection_config_from_hook() diff --git a/tests/operators/test_great_expectations.py b/tests/operators/test_great_expectations.py index 7290900..d3eb5b6 100644 --- a/tests/operators/test_great_expectations.py +++ b/tests/operators/test_great_expectations.py @@ -12,6 +12,7 @@ import logging import os +import tempfile import unittest.mock as mock from pathlib import Path @@ -859,57 +860,68 @@ def test_great_expectations_operator__make_connection_string_snowflake(mocker): def test_great_expectations_operator__make_connection_string_snowflake_pkey(mocker): - private_key_bytes = b"secret" - test_conn_conf = { - "url": URL.create( - drivername="snowflake", - username="user", - password="", - host="account.region-east-1", - database="database/schema", - query={"role": "role", "warehouse": "warehouse", "authenticator": "snowflake", "application": "AIRFLOW"}, - ).render_as_string(hide_password=False), - "connect_args": {"private_key": private_key_bytes}, - } - operator = GreatExpectationsOperator( - task_id="task_id", - data_context_config=in_memory_data_context_config, - data_asset_name="test_runtime_data_asset", - conn_id="snowflake_default", - query_to_validate="SELECT * FROM db;", - expectation_suite_name="suite", - ) - operator.conn = Connection( - conn_id="snowflake_default", - conn_type="snowflake", - host="connection", - login="user", - password="password", - schema="schema", - port=5439, - extra={ - "extra__snowflake__role": "role", - "extra__snowflake__warehouse": "warehouse", - "extra__snowflake__database": "database", - "extra__snowflake__region": "region-east-1", - "extra__snowflake__account": "account", - "extra__snowflake__private_key_file": "/path/to/key.p8", - }, - ) - operator.conn_type = operator.conn.conn_type + # create a temp key file + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + private_key_bytes = b"fake_key" + temp_file.write(private_key_bytes) + temp_file.flush() + test_conn_conf = { + "url": URL.create( + drivername="snowflake", + username="user", + password="", + host="account.region-east-1", + database="database/schema", + query={ + "role": "role", + "warehouse": "warehouse", + "authenticator": "snowflake", + "application": "AIRFLOW", + }, + ).render_as_string(hide_password=False), + "connect_args": {"private_key": private_key_bytes}, + } + operator = GreatExpectationsOperator( + task_id="task_id", + data_context_config=in_memory_data_context_config, + data_asset_name="test_runtime_data_asset", + conn_id="snowflake_default", + query_to_validate="SELECT * FROM db;", + expectation_suite_name="suite", + ) + operator.conn = Connection( + conn_id="snowflake_default", + conn_type="snowflake", + host="connection", + login="user", + password="password", + schema="schema", + port=5439, + extra={ + "extra__snowflake__role": "role", + "extra__snowflake__warehouse": "warehouse", + "extra__snowflake__database": "database", + "extra__snowflake__region": "region-east-1", + "extra__snowflake__account": "account", + "extra__snowflake__private_key_file": temp_file.name, + }, + ) + operator.conn_type = operator.conn.conn_type - mocker.patch( - "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_connection", return_value=operator.conn - ) - mocker.patch("great_expectations_provider.operators.great_expectations.Path.read_bytes", return_value=b"dummy") - mocked_key = mock.MagicMock(default_backend()) - mocked_key.private_bytes = mock.MagicMock(return_value=private_key_bytes) - mocker.patch( - "cryptography.hazmat.primitives.serialization.load_pem_private_key", - return_value=mocked_key, - ) + mocker.patch( + "airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_connection", return_value=operator.conn + ) + mocker.patch( + "great_expectations_provider.operators.great_expectations.Path.read_bytes", return_value=b"fake_key" + ) + mocked_key = mock.MagicMock(default_backend()) + mocked_key.private_bytes = mock.MagicMock(return_value=private_key_bytes) + mocker.patch( + "cryptography.hazmat.primitives.serialization.load_pem_private_key", + return_value=mocked_key, + ) - assert operator.make_connection_configuration() == test_conn_conf + assert operator.make_connection_configuration() == test_conn_conf def test_great_expectations_operator__make_connection_string_sqlite():