Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Opensearch operator set connection type #39788

Merged
Merged 12 commits on Jun 8, 2024
21 changes: 16 additions & 5 deletions airflow/providers/opensearch/hooks/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@

import json
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Any

from opensearchpy import OpenSearch, RequestsHttpConnection

if TYPE_CHECKING:
from opensearchpy import Connection as OpenSearchConnectionClass

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.utils.strings import to_boolean


class OpenSearchHook(BaseHook):
Expand All @@ -40,13 +44,20 @@ class OpenSearchHook(BaseHook):
conn_type = "opensearch"
hook_name = "OpenSearch Hook"

def __init__(self, open_search_conn_id: str, log_query: bool, **kwargs: Any):
def __init__(
self,
open_search_conn_id: str,
log_query: bool,
open_search_conn_class: type[OpenSearchConnectionClass] | None = RequestsHttpConnection,
**kwargs: Any,
):
super().__init__(**kwargs)
self.conn_id = open_search_conn_id
self.log_query = log_query

self.use_ssl = self.conn.extra_dejson.get("use_ssl", False)
self.verify_certs = self.conn.extra_dejson.get("verify_certs", False)
self.use_ssl = to_boolean(str(self.conn.extra_dejson.get("use_ssl", False)))
self.verify_certs = to_boolean(str(self.conn.extra_dejson.get("verify_certs", False)))
self.connection_class = open_search_conn_class
self.__SERVICE = "es"

@cached_property
Expand All @@ -62,7 +73,7 @@ def client(self) -> OpenSearch:
http_auth=auth,
use_ssl=self.use_ssl,
verify_certs=self.verify_certs,
connection_class=RequestsHttpConnection,
connection_class=self.connection_class,
)
return client

Expand Down
12 changes: 11 additions & 1 deletion airflow/providers/opensearch/operators/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from opensearchpy import RequestsHttpConnection
from opensearchpy.exceptions import OpenSearchException

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook

if TYPE_CHECKING:
from opensearchpy import Connection as OpenSearchConnectionClass

from airflow.utils.context import Context


Expand All @@ -42,6 +45,7 @@ class OpenSearchQueryOperator(BaseOperator):
:param search_object: A Search object from opensearch-dsl.
:param index_name: The name of the index to search for documents.
:param opensearch_conn_id: opensearch connection to use
:param opensearch_conn_class: opensearch connection class to use
:param log_query: Whether to log the query used. Defaults to True and logs query used.
"""

Expand All @@ -54,20 +58,26 @@ def __init__(
search_object: Any | None = None,
index_name: str | None = None,
opensearch_conn_id: str = "opensearch_default",
opensearch_conn_class: type[OpenSearchConnectionClass] | None = RequestsHttpConnection,
log_query: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.query = query
self.index_name = index_name
self.opensearch_conn_id = opensearch_conn_id
self.opensearch_conn_class = opensearch_conn_class
self.log_query = log_query
self.search_object = search_object

@cached_property
def hook(self) -> OpenSearchHook:
"""Get an instance of an OpenSearchHook."""
return OpenSearchHook(open_search_conn_id=self.opensearch_conn_id, log_query=self.log_query)
return OpenSearchHook(
open_search_conn_id=self.opensearch_conn_id,
open_search_conn_class=self.opensearch_conn_class,
log_query=self.log_query,
)

def execute(self, context: Context) -> Any:
"""Execute a search against a given index or a Search object on an OpenSearch Cluster."""
Expand Down
28 changes: 28 additions & 0 deletions tests/providers/opensearch/hooks/test_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
# under the License.
from __future__ import annotations

from unittest import mock

import opensearchpy
import pytest
from opensearchpy import Urllib3HttpConnection

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook

pytestmark = pytest.mark.db_test


MOCK_SEARCH_RETURN = {"status": "test"}
DEFAULT_CONN = opensearchpy.connection.http_requests.RequestsHttpConnection


class TestOpenSearchHook:
Expand All @@ -46,3 +52,25 @@ def test_delete_check_parameters(self):
hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
with pytest.raises(AirflowException, match="must include one of either a query or a document id"):
hook.delete(index_name="test_index")

@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_hook_param_bool(self, mock_get_connection):
mock_conn = Connection(
conn_id="opensearch_default", extra={"use_ssl": "True", "verify_certs": "True"}
)
mock_get_connection.return_value = mock_conn
hook = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)

assert isinstance(hook.use_ssl, bool)
assert isinstance(hook.verify_certs, bool)

def test_load_conn_param(self, mock_hook):
hook_default = OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
assert hook_default.connection_class == DEFAULT_CONN

hook_Urllib3 = OpenSearchHook(
open_search_conn_id="opensearch_default",
log_query=True,
open_search_conn_class=Urllib3HttpConnection,
)
assert hook_Urllib3.connection_class == Urllib3HttpConnection