Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Opensearch operator set connection type #39788

Merged
merged 12 commits into from
Jun 8, 2024
48 changes: 44 additions & 4 deletions airflow/providers/opensearch/hooks/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.utils.module_loading import import_string
from airflow.utils.strings import to_boolean

# Allow-list of connection class names that ``OpenSearchHook._load_conn_type`` is willing to
# import from the ``opensearchpy`` package. Any other name is not imported and the hook falls
# back to ``RequestsHttpConnection``.
DEFAULT_CONN_TYPES = frozenset(
{"RequestsHttpConnection", "Urllib3HttpConnection", "AsyncHttpConnection", "PoolingConnection"}
)


class OpenSearchHook(BaseHook):
Expand All @@ -40,15 +46,49 @@ class OpenSearchHook(BaseHook):
conn_type = "opensearch"
hook_name = "OpenSearch Hook"

def __init__(self, open_search_conn_id: str, log_query: bool, **kwargs: Any):
def __init__(
self,
open_search_conn_id: str,
log_query: bool,
open_search_conn_class: str = "RequestsHttpConnection",
**kwargs: Any,
):
super().__init__(**kwargs)
self.conn_id = open_search_conn_id
self.log_query = log_query

self.use_ssl = self.conn.extra_dejson.get("use_ssl", False)
self.verify_certs = self.conn.extra_dejson.get("verify_certs", False)
self.use_ssl = to_boolean(str(self.conn.extra_dejson.get("use_ssl", False)))
self.verify_certs = to_boolean(str(self.conn.extra_dejson.get("verify_certs", False)))
self.connection_class = self._load_conn_type(open_search_conn_class)
self.__SERVICE = "es"

def _load_conn_type(self, module_name: str | None) -> Any:
    """
    Resolve a connection class name to the matching ``opensearchpy`` class.

    Only names listed in ``DEFAULT_CONN_TYPES`` are imported. This protects against the
    execution of random modules.

    :param module_name: Bare class name to load from ``opensearchpy``
        (for example ``"Urllib3HttpConnection"``). ``None`` or an empty string selects
        the fallback.
    :return: The imported connection class, or ``RequestsHttpConnection`` when
        ``module_name`` is empty or not allow-listed.
    :raises AirflowException: If an allow-listed class cannot be imported.
    """
    if module_name and module_name in DEFAULT_CONN_TYPES:
        try:
            # Keep the try body to the single call that can raise ImportError.
            module = import_string(f"opensearchpy.{module_name}")
        except ImportError as error:
            self.log.debug("Cannot import connection type '%s' due to: %s", module_name, error)
            # Chain the original error so the real import failure stays in the traceback.
            raise AirflowException(error) from error
        self.log.info("Loaded connection type: %s", module_name)
        return module

    if module_name:
        # One placeholder per argument; the allow-list is sorted because frozenset
        # iteration order is not stable between runs.
        self.log.warning(
            "Skipping import of connection type '%s'. The class should be listed in {%s}. "
            "Defaulting to RequestsHttpConnection",
            module_name,
            ", ".join(sorted(DEFAULT_CONN_TYPES)),
        )
    # fallback
    return RequestsHttpConnection

@cached_property
def conn(self):
return self.get_connection(self.conn_id)
Expand All @@ -62,7 +102,7 @@ def client(self) -> OpenSearch:
http_auth=auth,
use_ssl=self.use_ssl,
verify_certs=self.verify_certs,
connection_class=RequestsHttpConnection,
connection_class=self.connection_class,
)
return client

Expand Down
8 changes: 7 additions & 1 deletion airflow/providers/opensearch/operators/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,26 @@ def __init__(
search_object: Any | None = None,
index_name: str | None = None,
opensearch_conn_id: str = "opensearch_default",
opensearch_conn_class: str = "RequestsHttpConnection",
log_query: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.query = query
self.index_name = index_name
self.opensearch_conn_id = opensearch_conn_id
self.opensearch_conn_class = opensearch_conn_class
self.log_query = log_query
self.search_object = search_object

@cached_property
def hook(self) -> OpenSearchHook:
"""Get an instance of an OpenSearchHook."""
return OpenSearchHook(open_search_conn_id=self.opensearch_conn_id, log_query=self.log_query)
return OpenSearchHook(
open_search_conn_id=self.opensearch_conn_id,
open_search_conn_class=self.opensearch_conn_class,
log_query=self.log_query,
)

def execute(self, context: Context) -> Any:
"""Execute a search against a given index or a Search object on an OpenSearch Cluster."""
Expand Down
35 changes: 35 additions & 0 deletions tests/providers/opensearch/hooks/test_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest
from opensearchpy.connection.http_requests import RequestsHttpConnection
from opensearchpy.connection.http_urllib3 import Urllib3HttpConnection

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook

pytestmark = pytest.mark.db_test


MOCK_SEARCH_RETURN = {"status": "test"}
DEFAULT_CONN = RequestsHttpConnection


class TestOpenSearchHook:
Expand All @@ -46,3 +52,32 @@ def test_delete_check_parameters(self):
hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
with pytest.raises(AirflowException, match="must include one of either a query or a document id"):
hook.delete(index_name="test_index")

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_hook_param_bool(self, mock_get_connection):
    """String-valued ``use_ssl``/``verify_certs`` extras are coerced to real booleans."""
    mock_conn = Connection(
        conn_id="opensearch_default", extra={"use_ssl": "True", "verify_certs": "True"}
    )
    mock_get_connection.return_value = mock_conn
    hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)

    # An identity check against True pins both the type and the value; a plain
    # isinstance(..., bool) check would also pass if "True" were coerced to False.
    assert hook.use_ssl is True
    assert hook.verify_certs is True

def test_load_conn_param(self, mock_hook):
    """Check which connection class the hook resolves for default, valid and invalid names."""
    conn_id = "opensearch_default"

    # No class name given: the hook keeps its default connection class.
    default_hook = OpenSearchHook(open_search_conn_id=conn_id, log_query=True)
    assert issubclass(default_hook.connection_class, DEFAULT_CONN)

    # An allow-listed class name is loaded as requested.
    urllib3_hook = OpenSearchHook(
        open_search_conn_id=conn_id,
        log_query=True,
        open_search_conn_class="Urllib3HttpConnection",
    )
    assert issubclass(urllib3_hook.connection_class, Urllib3HttpConnection)

    # A name outside the allow-list falls back to the default connection class.
    invalid_hook = OpenSearchHook(
        open_search_conn_id=conn_id,
        log_query=True,
        open_search_conn_class="invalid_connection",
    )
    assert issubclass(invalid_hook.connection_class, DEFAULT_CONN)