Skip to content

Commit

Permalink
Untangle cyclic deps configuration <> secrets
Browse files Browse the repository at this point in the history
  • Loading branch information
potiuk committed Aug 26, 2020
1 parent 3a34962 commit 5a2babb
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 111 deletions.
57 changes: 54 additions & 3 deletions airflow/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.

import copy
import json
import logging
import multiprocessing
import os
Expand All @@ -30,12 +31,14 @@
from collections import OrderedDict
# Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute
from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError # type: ignore
from typing import Dict, Optional, Tuple, Union
from json.decoder import JSONDecodeError
from typing import Dict, List, Optional, Tuple, Union

import yaml
from cryptography.fernet import Fernet

from airflow.exceptions import AirflowConfigException
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend
from airflow.utils.module_loading import import_string

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -87,8 +90,7 @@ def run_command(command):

def _get_config_value_from_secret_backend(config_key):
"""
Get a config option value from the custom secrets backend.

:param config_key: key handed to the backend's ``get_config``
:return: whatever ``get_config`` returns, or ``None`` when no custom secrets backend is available
"""
# NOTE(review): this hunk appears to show both sides of the diff - the next two lines look like
# the removed lazy import and lookup through ``airflow.secrets``, and the third line their
# replacement that calls ``get_custom_secret_backend`` directly. Confirm against the commit.
from airflow import secrets
secrets_client = secrets.get_custom_secret_backend()
secrets_client = get_custom_secret_backend()
# A falsy client means there is nothing to query.
if not secrets_client:
return None
return secrets_client.get_config(config_key)
Expand Down Expand Up @@ -951,3 +953,52 @@ def set(*args, **kwargs): # noqa pylint: disable=redefined-builtin
stacklevel=2
)
return conf.set(*args, **kwargs)


def ensure_secrets_loaded() -> List[BaseSecretsBackend]:
    """
    Return the secrets backends to search, building them again when needed.

    A module-level list holding exactly two entries is taken to mean that only
    the two default backends were loaded. In that case a fresh list is built
    with ``initialize_secrets_backends()`` and returned (the module-level list
    is not reassigned here). Otherwise the existing list is returned.

    :return: list of secrets backend instances
    """
    only_defaults_loaded = len(secrets_backend_list) == 2
    return initialize_secrets_backends() if only_defaults_loaded else secrets_backend_list


def get_custom_secret_backend() -> Optional[BaseSecretsBackend]:
    """
    Instantiate the custom secrets backend configured in airflow.cfg, if any.

    The backend class comes from ``[secrets] backend`` and its constructor
    keyword arguments from ``[secrets] backend_kwargs``, which is parsed as
    JSON (``'{}'`` is used as the fallback value for the lookup).

    :return: an instance of the configured backend class, or ``None`` when no
        backend class is configured
    """
    secrets_backend_cls = conf.getimport(section='secrets', key='backend')
    if not secrets_backend_cls:
        return None

    try:
        alternative_secrets_config_dict = json.loads(
            conf.get(section='secrets', key='backend_kwargs', fallback='{}')
        )
    except JSONDecodeError:
        # Keep the best-effort fallback to "no kwargs", but surface the
        # misconfiguration instead of swallowing it silently. The raw value is
        # deliberately not logged because it may contain credentials.
        log.warning(
            "Could not parse [secrets] backend_kwargs as JSON; "
            "the secrets backend will be instantiated without keyword arguments."
        )
        alternative_secrets_config_dict = {}

    return secrets_backend_cls(**alternative_secrets_config_dict)


def initialize_secrets_backends() -> List[BaseSecretsBackend]:
    """
    Build the ordered list of secrets backend instances.

    The custom backend, when one is configured, comes first. It is followed by
    one instance of every class named in ``DEFAULT_SECRETS_SEARCH_PATH``, each
    imported by its dotted path and instantiated without arguments.

    :return: newly created list of backend instances
    """
    configured_backend = get_custom_secret_backend()
    backends = [] if configured_backend is None else [configured_backend]
    # The generator keeps the import-then-instantiate order, one path at a time.
    backends.extend(import_string(dotted_path)() for dotted_path in DEFAULT_SECRETS_SEARCH_PATH)
    return backends


# Built once when this module is imported. ensure_secrets_loaded() returns this
# list unless it holds exactly two backends, in which case it builds a new one.
secrets_backend_list = initialize_secrets_backends()
5 changes: 2 additions & 3 deletions airflow/hooks/base_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import random
from typing import Any, List

from airflow import secrets
from airflow.models.connection import Connection
from airflow.utils.log.logging_mixin import LoggingMixin

Expand All @@ -44,7 +43,7 @@ def get_connections(cls, conn_id: str) -> List[Connection]:
:param conn_id: connection id
:return: array of connections
"""
return secrets.get_connections(conn_id)
return Connection.get_connections_from_secrets(conn_id)

@classmethod
def get_connection(cls, conn_id: str) -> Connection:
Expand All @@ -54,7 +53,7 @@ def get_connection(cls, conn_id: str) -> Connection:
:param conn_id: connection id
:return: connection
"""
conn = random.choice(list(cls.get_connections(conn_id)))
conn = random.choice(cls.get_connections(conn_id))
if conn.host:
log.info(
"Using connection to: id: %s. Host: %s, Port: %s, Schema: %s, Login: %s, Password: %s, "
Expand Down
19 changes: 17 additions & 2 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
import json
import warnings
from json import JSONDecodeError
from typing import Dict, Optional
from typing import Dict, List, Optional
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse

from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import synonym

from airflow.exceptions import AirflowException
from airflow.configuration import ensure_secrets_loaded
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -380,3 +381,17 @@ def extra_dejson(self) -> Dict:
self.log.error("Failed parsing the json for conn_id %s", self.conn_id)

return obj

@classmethod
def get_connections_from_secrets(cls, conn_id: str) -> List['Connection']:
    """
    Look up the connections for ``conn_id`` across the loaded secrets backends.

    Backends are queried in the order given by ``ensure_secrets_loaded()``.
    The first non-empty result wins and is returned as a list; later backends
    are not queried.

    :param conn_id: connection id
    :return: list of connections from the first backend that has any
    :raises AirflowNotFoundException: if no backend returns a connection
    """
    # Lazy generator: a backend is only queried if all earlier ones came up empty.
    lookups = (
        backend.get_connections(conn_id=conn_id) for backend in ensure_secrets_loaded()
    )
    found = next((result for result in lookups if result), None)
    if found is None:
        raise AirflowNotFoundException("The conn_id `{0}` isn't defined".format(conn_id))
    return list(found)
20 changes: 17 additions & 3 deletions airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
# under the License.

import json
from typing import Any
from typing import Any, Optional

from cryptography.fernet import InvalidToken as InvalidFernetToken
from sqlalchemy import Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Session, synonym

from airflow.configuration import ensure_secrets_loaded
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.secrets import get_variable
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session

Expand Down Expand Up @@ -126,7 +126,7 @@ def get(
:param default_var: Default value of the Variable if the Variable doesn't exists
:param deserialize_json: Deserialize the value to a Python dict
"""
var_val = get_variable(key=key)
var_val = Variable.get_variable_from_secrets(key=key)
if var_val is None:
if default_var is not cls.__NO_DEFAULT_SENTINEL:
return default_var
Expand Down Expand Up @@ -181,3 +181,17 @@ def rotate_fernet_key(self):
fernet = get_fernet()
if self._val and self.is_encrypted:
self._val = fernet.rotate(self._val.encode('utf-8')).decode()

@staticmethod
def get_variable_from_secrets(key: str) -> Optional[str]:
    """
    Fetch an Airflow Variable value from the first secrets backend that has it.

    Backends are queried in the order given by ``ensure_secrets_loaded()``;
    the search stops at the first value that is not ``None``.

    :param key: Variable Key
    :return: Variable Value, or ``None`` when no backend provides one
    """
    # Lazy generator: later backends are only asked if earlier ones returned None.
    candidates = (
        backend.get_variable(key=key) for backend in ensure_secrets_loaded()
    )
    return next((value for value in candidates if value is not None), None)
93 changes: 1 addition & 92 deletions airflow/secrets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,102 +22,11 @@
* Metastore database
* AWS SSM Parameter store
"""
__all__ = ['BaseSecretsBackend', 'get_connections', 'get_variable', 'get_custom_secret_backend']
__all__ = ['BaseSecretsBackend', 'DEFAULT_SECRETS_SEARCH_PATH']

import json
from json import JSONDecodeError
from typing import TYPE_CHECKING, List, Optional

from airflow.configuration import conf
from airflow.exceptions import AirflowNotFoundException
from airflow.secrets.base_secrets import BaseSecretsBackend
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
from airflow.models.connection import Connection


CONFIG_SECTION = "secrets"
DEFAULT_SECRETS_SEARCH_PATH = [
"airflow.secrets.environment_variables.EnvironmentVariablesBackend",
"airflow.secrets.metastore.MetastoreBackend",
]


def get_connections(conn_id: str) -> List['Connection']:
    """
    Look up the connections for ``conn_id`` across the loaded secrets backends.

    Backends are queried in the order given by ``ensure_secrets_loaded()``.
    The first non-empty result wins and is returned as a list.

    :param conn_id: connection id
    :return: list of connections from the first backend that has any
    :raises AirflowNotFoundException: if no backend returns a connection
    """
    # Lazy generator: a backend is only queried if all earlier ones came up empty.
    lookups = (
        backend.get_connections(conn_id=conn_id) for backend in ensure_secrets_loaded()
    )
    found = next((result for result in lookups if result), None)
    if found is None:
        raise AirflowNotFoundException("The conn_id `{0}` isn't defined".format(conn_id))
    return list(found)


def get_variable(key: str) -> Optional[str]:
    """
    Fetch an Airflow Variable value from the first secrets backend that has it.

    Backends are queried in the order given by ``ensure_secrets_loaded()``;
    the search stops at the first value that is not ``None``.

    :param key: Variable Key
    :return: Variable Value, or ``None`` when no backend provides one
    """
    # Lazy generator: later backends are only asked if earlier ones returned None.
    candidates = (
        backend.get_variable(key=key) for backend in ensure_secrets_loaded()
    )
    return next((value for value in candidates if value is not None), None)


def get_custom_secret_backend() -> Optional[BaseSecretsBackend]:
    """
    Instantiate the custom secrets backend configured in airflow.cfg, if any.

    Both options are read from the ``CONFIG_SECTION`` section: ``backend``
    names the class and ``backend_kwargs`` holds its constructor keyword
    arguments as JSON (``'{}'`` is used as the fallback value for the lookup).

    :return: an instance of the configured backend class, or ``None`` when no
        backend class is configured
    """
    # Local import: no module-level logger is visible among this module's imports.
    import logging

    secrets_backend_cls = conf.getimport(section=CONFIG_SECTION, key='backend')
    if not secrets_backend_cls:
        return None

    try:
        alternative_secrets_config_dict = json.loads(
            conf.get(section=CONFIG_SECTION, key='backend_kwargs', fallback='{}')
        )
    except JSONDecodeError:
        # Keep the best-effort fallback to "no kwargs", but surface the
        # misconfiguration instead of swallowing it silently. The raw value is
        # deliberately not logged because it may contain credentials.
        logging.getLogger(__name__).warning(
            "Could not parse [secrets] backend_kwargs as JSON; "
            "the secrets backend will be instantiated without keyword arguments."
        )
        alternative_secrets_config_dict = {}

    return secrets_backend_cls(**alternative_secrets_config_dict)


def initialize_secrets_backends() -> List[BaseSecretsBackend]:
    """
    Build the ordered list of secrets backend instances.

    The custom backend, when one is configured, comes first. It is followed by
    one instance of every class named in ``DEFAULT_SECRETS_SEARCH_PATH``, each
    imported by its dotted path and instantiated without arguments.

    :return: newly created list of backend instances
    """
    configured_backend = get_custom_secret_backend()
    backends = [] if configured_backend is None else [configured_backend]
    # The generator keeps the import-then-instantiate order, one path at a time.
    backends.extend(import_string(dotted_path)() for dotted_path in DEFAULT_SECRETS_SEARCH_PATH)
    return backends


def ensure_secrets_loaded() -> List[BaseSecretsBackend]:
    """
    Return the secrets backends to search, building them again when needed.

    A module-level list holding exactly two entries is taken to mean that only
    the two default backends were loaded. In that case a fresh list is built
    with ``initialize_secrets_backends()`` and returned (the module-level list
    is not reassigned here). Otherwise the existing list is returned.

    :return: list of secrets backend instances
    """
    only_defaults_loaded = len(secrets_backend_list) == 2
    return initialize_secrets_backends() if only_defaults_loaded else secrets_backend_list


# Built once when this module is imported. ensure_secrets_loaded() returns this
# list unless it holds exactly two backends, in which case it builds a new one.
secrets_backend_list = initialize_secrets_backends()
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/secrets/test_systems_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

from moto import mock_ssm

from airflow.configuration import initialize_secrets_backends
from airflow.providers.amazon.aws.secrets.systems_manager import SystemsManagerParameterStoreBackend
from airflow.secrets import initialize_secrets_backends
from tests.test_utils.config import conf_vars


Expand Down
14 changes: 7 additions & 7 deletions tests/secrets/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import unittest
from unittest import mock

from airflow.models import Variable
from airflow.secrets import ensure_secrets_loaded, get_connections, get_variable, initialize_secrets_backends
from airflow.configuration import ensure_secrets_loaded, initialize_secrets_backends
from airflow.models import Connection, Variable
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_variables

Expand All @@ -30,15 +30,15 @@ class TestConnectionsFromSecrets(unittest.TestCase):
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connections")
def test_get_connections_second_try(self, mock_env_get, mock_meta_get):
mock_env_get.side_effect = [[]] # return empty list
get_connections("fake_conn_id")
Connection.get_connections_from_secrets("fake_conn_id")
mock_meta_get.assert_called_once_with(conn_id="fake_conn_id")
mock_env_get.assert_called_once_with(conn_id="fake_conn_id")

@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connections")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connections")
def test_get_connections_first_try(self, mock_env_get, mock_meta_get):
# Scenario: the environment-variable backend already yields a non-empty result.
mock_env_get.side_effect = [["something"]] # returns nonempty list
# NOTE(review): the next two lines appear to be the removed and the added side of the diff
# (old module-level helper vs. the new Connection classmethod) - confirm against the commit.
get_connections("fake_conn_id")
Connection.get_connections_from_secrets("fake_conn_id")
mock_env_get.assert_called_once_with(conn_id="fake_conn_id")
# NOTE(review): ``not_called()`` is not a Mock assertion - it only calls an auto-created child
# mock, so nothing is verified here. ``assert_not_called()`` is presumably what was intended.
mock_meta_get.not_called()

Expand Down Expand Up @@ -85,7 +85,7 @@ def test_backend_fallback_to_env_var(self, mock_get_uri):
backend_classes = [backend.__class__.__name__ for backend in backends]
self.assertIn('SystemsManagerParameterStoreBackend', backend_classes)

uri = get_connections(conn_id="test_mysql")
uri = Connection.get_connections_from_secrets(conn_id="test_mysql")

# Assert that SystemsManagerParameterStoreBackend.get_conn_uri was called
mock_get_uri.assert_called_once_with(conn_id='test_mysql')
Expand All @@ -109,7 +109,7 @@ def test_get_variable_second_try(self, mock_env_get, mock_meta_get):
Metastore DB
"""
mock_env_get.return_value = None
get_variable("fake_var_key")
Variable.get_variable_from_secrets("fake_var_key")
mock_meta_get.assert_called_once_with(key="fake_var_key")
mock_env_get.assert_called_once_with(key="fake_var_key")

Expand All @@ -121,7 +121,7 @@ def test_get_variable_first_try(self, mock_env_get, mock_meta_get):
Metastore DB
"""
mock_env_get.return_value = [["something"]] # returns nonempty list
get_variable("fake_var_key")
Variable.get_variable_from_secrets("fake_var_key")
mock_env_get.assert_called_once_with(key="fake_var_key")
mock_meta_get.not_called()

Expand Down

0 comments on commit 5a2babb

Please sign in to comment.