From 8a7c37281c49751de5ea5ec14ccdb2dfe86e6d72 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Wed, 26 Aug 2020 16:38:57 +0200 Subject: [PATCH] Untangle cyclic deps configuration <> secrets (#10559) --- airflow/configuration.py | 57 +++++++++++- airflow/hooks/base_hook.py | 5 +- airflow/models/connection.py | 19 +++- airflow/models/variable.py | 20 +++- airflow/secrets/__init__.py | 93 +------------------ .../aws/secrets/test_systems_manager.py | 2 +- tests/secrets/test_secrets.py | 14 +-- 7 files changed, 99 insertions(+), 111 deletions(-) diff --git a/airflow/configuration.py b/airflow/configuration.py index cbd281cefffb9..5d878f4d3eda2 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -17,6 +17,7 @@ # under the License. import copy +import json import logging import multiprocessing import os @@ -30,12 +31,14 @@ from collections import OrderedDict # Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError # type: ignore -from typing import Dict, Optional, Tuple, Union +from json.decoder import JSONDecodeError +from typing import Dict, List, Optional, Tuple, Union import yaml from cryptography.fernet import Fernet from airflow.exceptions import AirflowConfigException +from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend from airflow.utils.module_loading import import_string log = logging.getLogger(__name__) @@ -87,8 +90,7 @@ def run_command(command): def _get_config_value_from_secret_backend(config_key): """Get Config option values from Secret Backend""" - from airflow import secrets - secrets_client = secrets.get_custom_secret_backend() + secrets_client = get_custom_secret_backend() if not secrets_client: return None return secrets_client.get_config(config_key) @@ -951,3 +953,52 @@ def set(*args, **kwargs): # noqa pylint: disable=redefined-builtin stacklevel=2 ) return conf.set(*args, **kwargs) + + +def ensure_secrets_loaded() -> List[BaseSecretsBackend]: + """ + Ensure that all secrets backends are loaded. + If the secrets_backend_list contains only 2 default backends, reload it. + """ + # Check if the secrets_backend_list contains only 2 default backends + if len(secrets_backend_list) == 2: + return initialize_secrets_backends() + return secrets_backend_list + + +def get_custom_secret_backend() -> Optional[BaseSecretsBackend]: + """Get Secret Backend if defined in airflow.cfg""" + secrets_backend_cls = conf.getimport(section='secrets', key='backend') + + if secrets_backend_cls: + try: + alternative_secrets_config_dict = json.loads( + conf.get(section='secrets', key='backend_kwargs', fallback='{}') + ) + except JSONDecodeError: + alternative_secrets_config_dict = {} + + return secrets_backend_cls(**alternative_secrets_config_dict) + return None + + +def initialize_secrets_backends() -> List[BaseSecretsBackend]: + """ + * import secrets backend classes + * instantiate them and return them in a list + """ + backend_list = [] + + custom_secret_backend = get_custom_secret_backend() + + if custom_secret_backend is not None: + backend_list.append(custom_secret_backend) + + for class_name in DEFAULT_SECRETS_SEARCH_PATH: + secrets_backend_cls = import_string(class_name) + backend_list.append(secrets_backend_cls()) + + return backend_list + + +secrets_backend_list = initialize_secrets_backends() diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py index cd8a8cd6f2faa..9a54077519ea1 100644 --- a/airflow/hooks/base_hook.py +++ b/airflow/hooks/base_hook.py @@ -20,7 +20,6 @@ import random from typing import Any, List -from airflow import secrets from airflow.models.connection import Connection from airflow.utils.log.logging_mixin import LoggingMixin @@ -44,7 +43,7 @@ def get_connections(cls, conn_id: str) -> List[Connection]: :param conn_id: connection id :return: array of connections """ - return secrets.get_connections(conn_id) + return Connection.get_connections_from_secrets(conn_id) @classmethod def get_connection(cls, conn_id: str) -> Connection: @@ -54,7 +53,7 @@ def get_connection(cls, conn_id: str) -> Connection: :param conn_id: connection id :return: connection """ - conn = random.choice(list(cls.get_connections(conn_id))) + conn = random.choice(cls.get_connections(conn_id)) if conn.host: log.info( "Using connection to: id: %s. Host: %s, Port: %s, Schema: %s, Login: %s, Password: %s, " diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 4bedc22ad94b5..44e8e576a3444 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -19,14 +19,15 @@ import json import warnings from json import JSONDecodeError -from typing import Dict, Optional +from typing import Dict, List, Optional from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse from sqlalchemy import Boolean, Column, Integer, String from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import synonym -from airflow.exceptions import AirflowException +from airflow.configuration import ensure_secrets_loaded +from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet from airflow.utils.log.logging_mixin import LoggingMixin @@ -380,3 +381,17 @@ def extra_dejson(self) -> Dict: self.log.error("Failed parsing the json for conn_id %s", self.conn_id) return obj + + @classmethod + def get_connections_from_secrets(cls, conn_id: str) -> List['Connection']: + """ + Get all connections as an iterable. + + :param conn_id: connection id + :return: array of connections + """ + for secrets_backend in ensure_secrets_loaded(): + conn_list = secrets_backend.get_connections(conn_id=conn_id) + if conn_list: + return list(conn_list) + raise AirflowNotFoundException("The conn_id `{0}` isn't defined".format(conn_id)) diff --git a/airflow/models/variable.py b/airflow/models/variable.py index a9222c4468c22..17c8fda2411dc 100644 --- a/airflow/models/variable.py +++ b/airflow/models/variable.py @@ -17,16 +17,16 @@ # under the License. import json -from typing import Any +from typing import Any, Optional from cryptography.fernet import InvalidToken as InvalidFernetToken from sqlalchemy import Boolean, Column, Integer, String, Text from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import Session, synonym +from airflow.configuration import ensure_secrets_loaded from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet -from airflow.secrets import get_variable from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session @@ -126,7 +126,7 @@ def get( :param default_var: Default value of the Variable if the Variable doesn't exists :param deserialize_json: Deserialize the value to a Python dict """ - var_val = get_variable(key=key) + var_val = Variable.get_variable_from_secrets(key=key) if var_val is None: if default_var is not cls.__NO_DEFAULT_SENTINEL: return default_var @@ -181,3 +181,17 @@ def rotate_fernet_key(self): fernet = get_fernet() if self._val and self.is_encrypted: self._val = fernet.rotate(self._val.encode('utf-8')).decode() + + @staticmethod + def get_variable_from_secrets(key: str) -> Optional[str]: + """ + Get Airflow Variable by iterating over all Secret Backends. + + :param key: Variable Key + :return: Variable Value + """ + for secrets_backend in ensure_secrets_loaded(): + var_val = secrets_backend.get_variable(key=key) + if var_val is not None: + return var_val + return None diff --git a/airflow/secrets/__init__.py b/airflow/secrets/__init__.py index 639e62c8dac48..5b12c8a300137 100644 --- a/airflow/secrets/__init__.py +++ b/airflow/secrets/__init__.py @@ -22,102 +22,11 @@ * Metastore database * AWS SSM Parameter store """ -__all__ = ['BaseSecretsBackend', 'get_connections', 'get_variable', 'get_custom_secret_backend'] +__all__ = ['BaseSecretsBackend', 'DEFAULT_SECRETS_SEARCH_PATH'] -import json -from json import JSONDecodeError -from typing import TYPE_CHECKING, List, Optional - -from airflow.configuration import conf -from airflow.exceptions import AirflowNotFoundException from airflow.secrets.base_secrets import BaseSecretsBackend -from airflow.utils.module_loading import import_string - -if TYPE_CHECKING: - from airflow.models.connection import Connection - -CONFIG_SECTION = "secrets" DEFAULT_SECRETS_SEARCH_PATH = [ "airflow.secrets.environment_variables.EnvironmentVariablesBackend", "airflow.secrets.metastore.MetastoreBackend", ] - - -def get_connections(conn_id: str) -> List['Connection']: - """ - Get all connections as an iterable. - - :param conn_id: connection id - :return: array of connections - """ - for secrets_backend in ensure_secrets_loaded(): - conn_list = secrets_backend.get_connections(conn_id=conn_id) - if conn_list: - return list(conn_list) - - raise AirflowNotFoundException("The conn_id `{0}` isn't defined".format(conn_id)) - - -def get_variable(key: str) -> Optional[str]: - """ - Get Airflow Variable by iterating over all Secret Backends. - - :param key: Variable Key - :return: Variable Value - """ - for secrets_backend in ensure_secrets_loaded(): - var_val = secrets_backend.get_variable(key=key) - if var_val is not None: - return var_val - - return None - - -def get_custom_secret_backend() -> Optional[BaseSecretsBackend]: - """Get Secret Backend if defined in airflow.cfg""" - secrets_backend_cls = conf.getimport(section='secrets', key='backend') - - if secrets_backend_cls: - try: - alternative_secrets_config_dict = json.loads( - conf.get(section=CONFIG_SECTION, key='backend_kwargs', fallback='{}') - ) - except JSONDecodeError: - alternative_secrets_config_dict = {} - - return secrets_backend_cls(**alternative_secrets_config_dict) - return None - - -def initialize_secrets_backends() -> List[BaseSecretsBackend]: - """ - * import secrets backend classes - * instantiate them and return them in a list - """ - backend_list = [] - - custom_secret_backend = get_custom_secret_backend() - - if custom_secret_backend is not None: - backend_list.append(custom_secret_backend) - - for class_name in DEFAULT_SECRETS_SEARCH_PATH: - secrets_backend_cls = import_string(class_name) - backend_list.append(secrets_backend_cls()) - - return backend_list - - -def ensure_secrets_loaded() -> List[BaseSecretsBackend]: - """ - Ensure that all secrets backends are loaded. - If the secrets_backend_list contains only 2 default backends, reload it. - """ - # Check if the secrets_backend_list contains only 2 default backends - if len(secrets_backend_list) == 2: - return initialize_secrets_backends() - return secrets_backend_list - - -secrets_backend_list = initialize_secrets_backends() diff --git a/tests/providers/amazon/aws/secrets/test_systems_manager.py b/tests/providers/amazon/aws/secrets/test_systems_manager.py index d43fadc005559..fa7aaf72811ad 100644 --- a/tests/providers/amazon/aws/secrets/test_systems_manager.py +++ b/tests/providers/amazon/aws/secrets/test_systems_manager.py @@ -19,8 +19,8 @@ from moto import mock_ssm +from airflow.configuration import initialize_secrets_backends from airflow.providers.amazon.aws.secrets.systems_manager import SystemsManagerParameterStoreBackend -from airflow.secrets import initialize_secrets_backends from tests.test_utils.config import conf_vars diff --git a/tests/secrets/test_secrets.py b/tests/secrets/test_secrets.py index 7bb1ed37c8241..53a11a14b6c90 100644 --- a/tests/secrets/test_secrets.py +++ b/tests/secrets/test_secrets.py @@ -19,8 +19,8 @@ import unittest from unittest import mock -from airflow.models import Variable -from airflow.secrets import ensure_secrets_loaded, get_connections, get_variable, initialize_secrets_backends +from airflow.configuration import ensure_secrets_loaded, initialize_secrets_backends +from airflow.models import Connection, Variable from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_variables @@ -30,7 +30,7 @@ class TestConnectionsFromSecrets(unittest.TestCase): @mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connections") def test_get_connections_second_try(self, mock_env_get, mock_meta_get): mock_env_get.side_effect = [[]] # return empty list - get_connections("fake_conn_id") + Connection.get_connections_from_secrets("fake_conn_id") mock_meta_get.assert_called_once_with(conn_id="fake_conn_id") mock_env_get.assert_called_once_with(conn_id="fake_conn_id") @@ -38,7 +38,7 @@ def test_get_connections_second_try(self, mock_env_get, mock_meta_get): @mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connections") def test_get_connections_first_try(self, mock_env_get, mock_meta_get): mock_env_get.side_effect = [["something"]] # returns nonempty list - get_connections("fake_conn_id") + Connection.get_connections_from_secrets("fake_conn_id") mock_env_get.assert_called_once_with(conn_id="fake_conn_id") mock_meta_get.not_called() @@ -85,7 +85,7 @@ def test_backend_fallback_to_env_var(self, mock_get_uri): backend_classes = [backend.__class__.__name__ for backend in backends] self.assertIn('SystemsManagerParameterStoreBackend', backend_classes) - uri = get_connections(conn_id="test_mysql") + uri = Connection.get_connections_from_secrets(conn_id="test_mysql") # Assert that SystemsManagerParameterStoreBackend.get_conn_uri was called mock_get_uri.assert_called_once_with(conn_id='test_mysql') @@ -109,7 +109,7 @@ def test_get_variable_second_try(self, mock_env_get, mock_meta_get): Metastore DB """ mock_env_get.return_value = None - get_variable("fake_var_key") + Variable.get_variable_from_secrets("fake_var_key") mock_meta_get.assert_called_once_with(key="fake_var_key") mock_env_get.assert_called_once_with(key="fake_var_key") @@ -121,7 +121,7 @@ def test_get_variable_first_try(self, mock_env_get, mock_meta_get): Metastore DB """ mock_env_get.return_value = [["something"]] # returns nonempty list - get_variable("fake_var_key") + Variable.get_variable_from_secrets("fake_var_key") mock_env_get.assert_called_once_with(key="fake_var_key") mock_meta_get.not_called()