Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plugins: load plugins from providers #32692

Merged
merged 1 commit into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions airflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,17 @@ def __getattr__(name: str):
return val


if not settings.LAZY_LOAD_PLUGINS:
from airflow import plugins_manager

plugins_manager.ensure_plugins_loaded()

if not settings.LAZY_LOAD_PROVIDERS:
from airflow import providers_manager

manager = providers_manager.ProvidersManager()
manager.initialize_providers_list()
manager.initialize_providers_hooks()
manager.initialize_providers_extra_links()
if not settings.LAZY_LOAD_PLUGINS:
from airflow import plugins_manager

plugins_manager.ensure_plugins_loaded()


# This is never executed, but tricks static analyzers (PyDev, PyCharm, etc.)
Expand Down
32 changes: 31 additions & 1 deletion airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from airflow import settings
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.file import find_path_from_directory
from airflow.utils.module_loading import qualname
from airflow.utils.module_loading import import_string, qualname

if TYPE_CHECKING:
from airflow.hooks.base import BaseHook
Expand All @@ -50,6 +50,7 @@
import_errors: dict[str, str] = {}

plugins: list[AirflowPlugin] | None = None
loaded_plugins: set[str] = set()

# Plugin components to integrate as modules
registered_hooks: list[BaseHook] | None = None
Expand Down Expand Up @@ -205,10 +206,16 @@ def is_valid_plugin(plugin_obj):
def register_plugin(plugin_instance):
    """
    Start plugin load and register it after successful initialization.

    If a plugin with the same name is already registered, do nothing.

    :param plugin_instance: subclass of AirflowPlugin
    """
    global plugins

    if plugin_instance.name in loaded_plugins:
        return

    # Run the load hook before touching the registries: if ``on_load`` raises,
    # the plugin must not be remembered as loaded, otherwise ``loaded_plugins``
    # and ``plugins`` disagree and later registration attempts are silently skipped.
    plugin_instance.on_load()
    plugins.append(plugin_instance)
    loaded_plugins.add(plugin_instance.name)

Expand Down Expand Up @@ -267,6 +274,26 @@ def load_plugins_from_plugin_directory():
import_errors[file_path] = str(e)


def load_providers_plugins():
    """
    Import and register every plugin listed by the providers manager.

    Classes rejected by ``is_valid_plugin`` are skipped with a warning. An
    ``ImportError`` raised while handling one plugin is logged and the loop
    moves on to the next plugin.
    """
    from airflow.providers_manager import ProvidersManager

    log.debug("Loading plugins from providers")
    manager = ProvidersManager()
    manager.initialize_providers_plugins()
    for plugin_info in manager.plugins:
        plugin_name, class_path = plugin_info.name, plugin_info.plugin_class
        log.debug("Importing plugin %s from class %s", plugin_name, class_path)

        try:
            plugin_cls = import_string(class_path)
            if is_valid_plugin(plugin_cls):
                register_plugin(plugin_cls)
            else:
                log.warning("Plugin %s is not a valid plugin", plugin_name)
        except ImportError:
            log.exception("Failed to load plugin %s from class name %s", plugin_name, class_path)


def make_module(name: str, objects: list[Any]):
"""Creates new module."""
if not objects:
Expand Down Expand Up @@ -306,6 +333,9 @@ def ensure_plugins_loaded():
load_plugins_from_plugin_directory()
load_entrypoint_plugins()

if not settings.LAZY_LOAD_PROVIDERS:
load_providers_plugins()

# We don't do anything with these for now, but we want to keep track of
# them so we can integrate them in to the UI's Connection screens
for plugin in plugins:
Expand Down
36 changes: 36 additions & 0 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ class TriggerInfo(NamedTuple):
integration_name: str


class PluginInfo(NamedTuple):
    """Describes one plugin declared by a provider: its name, class path and provider."""

    # Plugin name as given in the provider's plugin entry
    name: str
    # Dotted import path of the plugin class
    plugin_class: str
    # Name of the provider package that declares the plugin
    provider_name: str


class HookInfo(NamedTuple):
"""Hook information."""

Expand Down Expand Up @@ -421,6 +429,8 @@ def __init__(self):
self._customized_form_fields_schema_validator = (
_create_customized_form_field_behaviours_schema_validator()
)
# Set of plugins contained in providers
self._plugins_set: set[PluginInfo] = set()

@provider_info_cache("list")
def initialize_providers_list(self):
Expand Down Expand Up @@ -516,6 +526,11 @@ def initialize_providers_auth_backends(self):
self.initialize_providers_list()
self._discover_auth_backends()

@provider_info_cache("plugins")
def initialize_providers_plugins(self):
    """Initialize the providers list, then discover the plugins the providers declare."""
    self.initialize_providers_list()
    self._discover_plugins()

def _discover_all_providers_from_packages(self) -> None:
"""
Discover all providers by scanning packages installed.
Expand Down Expand Up @@ -1024,6 +1039,21 @@ def _discover_config(self) -> None:
if provider.data.get("config"):
self._provider_configs[provider_package] = provider.data.get("config")

def _discover_plugins(self) -> None:
    """Collect a ``PluginInfo`` for each plugin entry found in the providers' data."""
    for package_name, provider_info in self._provider_dict.items():
        for entry in provider_info.data.get("plugins", ()):
            class_path = entry["plugin-class"]
            if _correctness_check(package_name, class_path, provider_info):
                self._plugins_set.add(
                    PluginInfo(name=entry["name"], plugin_class=class_path, provider_name=package_name)
                )
            else:
                log.warning("Plugin not loaded due to above correctness check problem.")

@provider_info_cache("triggers")
def initialize_providers_triggers(self):
"""Initialization of providers triggers."""
Expand Down Expand Up @@ -1062,6 +1092,12 @@ def hooks(self) -> MutableMapping[str, HookInfo | None]:
# When we return hooks here it will only be used to retrieve hook information
return self._hooks_lazy_dict

@property
def plugins(self) -> list[PluginInfo]:
    """Return the provider plugins' information, ordered by plugin class path."""
    self.initialize_providers_plugins()
    ordered = list(self._plugins_set)
    ordered.sort(key=lambda info: info.plugin_class)
    return ordered

@property
def taskflow_decorators(self) -> dict[str, TaskDecorator]:
self.initialize_providers_taskflow_decorator()
Expand Down
30 changes: 29 additions & 1 deletion tests/always/test_providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from wtforms import BooleanField, Field, StringField

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers_manager import HookClassProvider, LazyDictWithCache, ProviderInfo, ProvidersManager
from airflow.providers_manager import (
HookClassProvider,
LazyDictWithCache,
PluginInfo,
ProviderInfo,
ProvidersManager,
)


class TestProviderManager:
Expand Down Expand Up @@ -157,6 +163,28 @@ def test_already_registered_conn_type_in_provide(self):
" and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
) in self._caplog.records[0].message

def test_providers_manager_register_plugins(self):
    """``_discover_plugins`` records a ``PluginInfo`` for a plugin declared in provider data."""
    provider_name = "apache-airflow-providers-apache-hive"
    plugin_class = "airflow.providers.apache.hive.plugins.hive.HivePlugin"
    manager = ProvidersManager()
    manager._provider_dict[provider_name] = ProviderInfo(
        version="0.0.1",
        data={"plugins": [{"name": "plugin1", "plugin-class": plugin_class}]},
        package_or_source="package",
    )
    manager._discover_plugins()
    discovered = manager._plugins_set
    assert len(discovered) == 1
    assert discovered.pop() == PluginInfo(
        name="plugin1",
        plugin_class=plugin_class,
        provider_name=provider_name,
    )

def test_hooks(self):
with pytest.warns(expected_warning=None) as warning_records:
with self._caplog.at_level(logging.WARNING):
Expand Down
33 changes: 33 additions & 0 deletions tests/plugins/test_plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ class AirflowNoMenuViewsPlugin(AirflowPlugin):


class TestPluginsManager:
@pytest.fixture(autouse=True, scope="function")
def clean_plugins(self):
    """
    Give each test an empty plugin registry and restore the previous one afterwards.

    The fixture replaces the ``plugins`` and ``loaded_plugins`` module globals of
    ``airflow.plugins_manager``. Without the teardown step the emptied values would
    stay in place after the last test of this class and leak into later tests.
    """
    from airflow import plugins_manager

    saved_loaded_plugins = plugins_manager.loaded_plugins
    saved_plugins = plugins_manager.plugins
    plugins_manager.loaded_plugins = set()
    plugins_manager.plugins = []
    yield
    # Teardown: put back whatever state was there before the test ran.
    plugins_manager.loaded_plugins = saved_loaded_plugins
    plugins_manager.plugins = saved_plugins

def test_no_log_when_no_plugins(self, caplog):

with mock_plugin_manager(plugins=[]):
Expand Down Expand Up @@ -378,6 +385,32 @@ def test_registering_plugin_listeners(self):
assert get_listener_manager().has_listeners
assert get_listener_manager().pm.get_plugins().pop().__name__ == "tests.listeners.empty_listener"

def test_should_import_plugin_from_providers(self):
    """Loading provider plugins into an empty registry registers at least two plugins."""
    from airflow import plugins_manager

    empty_registry = mock.patch("airflow.plugins_manager.plugins", [])
    with empty_registry:
        assert len(plugins_manager.plugins) == 0
        plugins_manager.load_providers_plugins()
        assert len(plugins_manager.plugins) >= 2

def test_does_not_double_import_entrypoint_provider_plugins(self):
    """Entry-point loading followed by provider loading must leave exactly two plugins."""
    from airflow import plugins_manager

    # NOTE(review): the two mocks below are built but never passed to or patched
    # into anything in this test -- presumably they were meant to drive entry-point
    # discovery; confirm and wire them in (or drop them).
    entry_point = mock.Mock(module="module_name_plugin")
    # ``name`` is consumed by the Mock constructor itself, so set it as an attribute.
    entry_point.name = "test-entrypoint-plugin"

    dist = mock.Mock(
        metadata={"Name": "test-entrypoint-plugin"},
        version="1.0.0",
        entry_points=[entry_point],
    )

    with mock.patch("airflow.plugins_manager.plugins", []):
        assert len(plugins_manager.plugins) == 0
        plugins_manager.load_entrypoint_plugins()
        plugins_manager.load_providers_plugins()
        assert len(plugins_manager.plugins) == 2


class TestPluginsDirectorySource:
def test_should_return_correct_path_name(self):
Expand Down