Skip to content

Commit

Permalink
plugins: load plugins from providers
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <[email protected]>
  • Loading branch information
mobuchowski committed Jul 19, 2023
1 parent 0fbef49 commit a126e8f
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 7 deletions.
9 changes: 4 additions & 5 deletions airflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,17 @@ def __getattr__(name: str):
return val


if not settings.LAZY_LOAD_PLUGINS:
from airflow import plugins_manager

plugins_manager.ensure_plugins_loaded()

if not settings.LAZY_LOAD_PROVIDERS:
from airflow import providers_manager

manager = providers_manager.ProvidersManager()
manager.initialize_providers_list()
manager.initialize_providers_hooks()
manager.initialize_providers_extra_links()
if not settings.LAZY_LOAD_PLUGINS:
from airflow import plugins_manager

plugins_manager.ensure_plugins_loaded()


# This is never executed, but tricks static analyzers (PyDev, PyCharm,)
Expand Down
33 changes: 32 additions & 1 deletion airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from airflow import settings
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.file import find_path_from_directory
from airflow.utils.module_loading import qualname
from airflow.utils.module_loading import import_string, qualname

if TYPE_CHECKING:
from airflow.hooks.base import BaseHook
Expand All @@ -50,6 +50,7 @@
import_errors: dict[str, str] = {}

plugins: list[AirflowPlugin] | None = None
loaded_plugins: set[str] = set()

# Plugin components to integrate as modules
registered_hooks: list[BaseHook] | None = None
Expand Down Expand Up @@ -205,10 +206,16 @@ def is_valid_plugin(plugin_obj):
def register_plugin(plugin_instance):
"""
Start plugin load and register it after success initialization.
If plugin is already registered, do nothing.
:param plugin_instance: subclass of AirflowPlugin
"""
global plugins

if plugin_instance.__name__ in loaded_plugins:
return

loaded_plugins.add(plugin_instance.__name__)
plugin_instance.on_load()
plugins.append(plugin_instance)

Expand Down Expand Up @@ -267,6 +274,27 @@ def load_plugins_from_plugin_directory():
import_errors[file_path] = str(e)


def load_providers_plugins():
from airflow.providers_manager import ProvidersManager

global import_errors
log.debug("Loading plugins from providers")
providers_manager = ProvidersManager()
providers_manager.initialize_providers_plugins()
for plugin in providers_manager.plugins:
log.debug("Importing plugin %s from class %s", plugin.name, plugin.plugin_class)

try:
plugin_instance = import_string(plugin.plugin_class)
if not is_valid_plugin(plugin_instance):
log.warning("Plugin %s is not a valid plugin", plugin.name)
continue
register_plugin(plugin_instance)
except ImportError:
log.exception("Failed to load plugin %s from class name %s", plugin.name, plugin.plugin_class)
continue


def make_module(name: str, objects: list[Any]):
"""Creates new module."""
if not objects:
Expand Down Expand Up @@ -306,6 +334,9 @@ def ensure_plugins_loaded():
load_plugins_from_plugin_directory()
load_entrypoint_plugins()

if not settings.LAZY_LOAD_PROVIDERS:
load_providers_plugins()

# We don't do anything with these for now, but we want to keep track of
# them so we can integrate them in to the UI's Connection screens
for plugin in plugins:
Expand Down
35 changes: 35 additions & 0 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,14 @@ class TriggerInfo(NamedTuple):
integration_name: str


class PluginInfo(NamedTuple):
"""Plugin class, name and provider it comes from."""

name: str
plugin_class: str
provider_name: str


class HookInfo(NamedTuple):
"""Hook information."""

Expand Down Expand Up @@ -407,6 +415,8 @@ def __init__(self):
self._customized_form_fields_schema_validator = (
_create_customized_form_field_behaviours_schema_validator()
)
# Set of plugins contained in providers
self._plugins_set: set[PluginInfo] = set()

@provider_info_cache("list")
def initialize_providers_list(self):
Expand Down Expand Up @@ -489,6 +499,11 @@ def initialize_providers_auth_backends(self):
self.initialize_providers_list()
self._discover_auth_backends()

@provider_info_cache("plugins")
def initialize_providers_plugins(self):
self.initialize_providers_list()
self._discover_plugins()

def _discover_all_providers_from_packages(self) -> None:
"""
Discover all providers by scanning packages installed.
Expand Down Expand Up @@ -991,6 +1006,20 @@ def _discover_config(self) -> None:
if provider.data.get("config"):
self._provider_configs[provider_package] = provider.data.get("config")

def _discover_plugins(self) -> None:
"""Retrieve all plugins defined in the providers."""
for provider_package, provider in self._provider_dict.items():
if provider.data.get("plugins"):
for plugin_dict in provider.data["plugins"]:
if _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
self._plugins_set.add(
PluginInfo(
name=plugin_dict["name"],
plugin_class=plugin_dict["plugin-class"],
provider_name=provider_package,
)
)

@provider_info_cache("triggers")
def initialize_providers_triggers(self):
"""Initialization of providers triggers."""
Expand Down Expand Up @@ -1029,6 +1058,12 @@ def hooks(self) -> MutableMapping[str, HookInfo | None]:
# When we return hooks here it will only be used to retrieve hook information
return self._hooks_lazy_dict

@property
def plugins(self) -> list[PluginInfo]:
"""Returns information about plugins available in providers."""
self.initialize_providers_plugins()
return sorted(self._plugins_set, key=lambda x: x.plugin_class)

@property
def taskflow_decorators(self) -> dict[str, TaskDecorator]:
self.initialize_providers_taskflow_decorator()
Expand Down
30 changes: 29 additions & 1 deletion tests/always/test_providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from wtforms import BooleanField, Field, StringField

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers_manager import HookClassProvider, LazyDictWithCache, ProviderInfo, ProvidersManager
from airflow.providers_manager import (
HookClassProvider,
LazyDictWithCache,
PluginInfo,
ProviderInfo,
ProvidersManager,
)


class TestProviderManager:
Expand Down Expand Up @@ -157,6 +163,28 @@ def test_already_registered_conn_type_in_provide(self):
" and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
) in self._caplog.records[0].message

def test_providers_manager_register_plugins(self):
providers_manager = ProvidersManager()
providers_manager._provider_dict["apache-airflow-providers-apache-hive"] = ProviderInfo(
version="0.0.1",
data={
"plugins": [
{
"name": "plugin1",
"plugin-class": "airflow.providers.apache.hive.plugins.hive.HivePlugin",
}
]
},
package_or_source="package",
)
providers_manager._discover_plugins()
assert len(providers_manager._plugins_set) == 1
assert providers_manager._plugins_set.pop() == PluginInfo(
name="plugin1",
plugin_class="airflow.providers.apache.hive.plugins.hive.HivePlugin",
provider_name="apache-airflow-providers-apache-hive",
)

def test_hooks(self):
with pytest.warns(expected_warning=None) as warning_records:
with self._caplog.at_level(logging.WARNING):
Expand Down
33 changes: 33 additions & 0 deletions tests/plugins/test_plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ class AirflowNoMenuViewsPlugin(AirflowPlugin):


class TestPluginsManager:
@pytest.fixture(autouse=True, scope="function")
def clean_plugins(self):
from airflow import plugins_manager

plugins_manager.loaded_plugins = set()
plugins_manager.plugins = []

def test_no_log_when_no_plugins(self, caplog):

with mock_plugin_manager(plugins=[]):
Expand Down Expand Up @@ -378,6 +385,32 @@ def test_registering_plugin_listeners(self):
assert get_listener_manager().has_listeners
assert get_listener_manager().pm.get_plugins().pop().__name__ == "tests.listeners.empty_listener"

def test_should_import_plugin_from_providers(self):
from airflow import plugins_manager

with mock.patch("airflow.plugins_manager.plugins", []):
assert len(plugins_manager.plugins) == 0
plugins_manager.load_providers_plugins()
assert len(plugins_manager.plugins) >= 2

def test_does_not_double_import_entrypoint_provider_plugins(self):
from airflow import plugins_manager

mock_entrypoint = mock.Mock()
mock_entrypoint.name = "test-entrypoint-plugin"
mock_entrypoint.module = "module_name_plugin"

mock_dist = mock.Mock()
mock_dist.metadata = {"Name": "test-entrypoint-plugin"}
mock_dist.version = "1.0.0"
mock_dist.entry_points = [mock_entrypoint]

with mock.patch("airflow.plugins_manager.plugins", []):
assert len(plugins_manager.plugins) == 0
plugins_manager.load_entrypoint_plugins()
plugins_manager.load_providers_plugins()
assert len(plugins_manager.plugins) == 2


class TestPluginsDirectorySource:
def test_should_return_correct_path_name(self):
Expand Down

0 comments on commit a126e8f

Please sign in to comment.