From e5b0b4efeae2ccfc338dbd87442879b3c0bef895 Mon Sep 17 00:00:00 2001 From: Albert Okiri Date: Mon, 25 Mar 2024 19:04:37 +0300 Subject: [PATCH 1/4] cohere 5.0.0 support --- airflow/providers/cohere/hooks/cohere.py | 37 ++++++++++++++++--- .../providers/cohere/operators/embedding.py | 22 ++++++++++- airflow/providers/cohere/provider.yaml | 2 +- generated/provider_dependencies.json | 2 +- tests/providers/cohere/hooks/test_cohere.py | 8 ++-- .../cohere/operators/test_embedding.py | 10 ++--- 6 files changed, 61 insertions(+), 20 deletions(-) diff --git a/airflow/providers/cohere/hooks/cohere.py b/airflow/providers/cohere/hooks/cohere.py index 2ce40c74d1e8e..052bdc352e05b 100644 --- a/airflow/providers/cohere/hooks/cohere.py +++ b/airflow/providers/cohere/hooks/cohere.py @@ -17,13 +17,18 @@ # under the License. from __future__ import annotations +import warnings from functools import cached_property -from typing import Any +from typing import TYPE_CHECKING, Any import cohere +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook +if TYPE_CHECKING: + from cohere.core.request_options import RequestOptions + class CohereHook(BaseHook): """ @@ -34,6 +39,17 @@ class CohereHook(BaseHook): :param conn_id: :ref:`Cohere connection id ` :param timeout: Request timeout in seconds. :param max_retries: Maximal number of retries for requests. + :param request_options: Request-specific configuration. + Fields: + - timeout_in_seconds: int. The number of seconds to await an API call before timing out. + + - max_retries: int. The max number of retries to attempt if the API call fails. + + - additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict + + - additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict + + - additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict """ conn_name_attr = "conn_id" @@ -46,23 +62,30 @@ def __init__( conn_id: str = default_conn_name, timeout: int | None = None, max_retries: int | None = None, + request_options: RequestOptions | None = None, ) -> None: super().__init__() self.conn_id = conn_id self.timeout = timeout self.max_retries = max_retries + self.request_options = request_options + if self.max_retries: + warnings.warn( + "Argument `max_retries` is deprecated. Please use `request_options` dict for function-specific request configuration instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + self.request_options = {"max_retries": self.max_retries} if self.request_options is None else self.request_options.update({"max_retries": self.max_retries}) @cached_property def get_conn(self) -> cohere.Client: # type: ignore[override] conn = self.get_connection(self.conn_id) - return cohere.Client( - api_key=conn.password, timeout=self.timeout, max_retries=self.max_retries, api_url=conn.host - ) + return cohere.Client(api_key=conn.password, timeout=self.timeout, base_url=conn.host) def create_embeddings( self, texts: list[str], model: str = "embed-multilingual-v2.0" ) -> list[list[float]]: - response = self.get_conn.embed(texts=texts, model=model) + response = self.get_conn.embed(texts=texts, model=model, request_options=self.request_options) embeddings = response.embeddings return embeddings @@ -76,8 +99,10 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: } def test_connection(self) -> tuple[bool, str]: + if self.max_retries: + self.request_options.update(max_retries=self.max_retries) try: - self.get_conn.generate("Test", max_tokens=10) + self.get_conn.generate(prompt="Test", max_tokens=10, request_options=self.request_options) return True, "Connection established" except Exception as e: return False, str(e) diff --git a/airflow/providers/cohere/operators/embedding.py b/airflow/providers/cohere/operators/embedding.py index dba95e7e8f661..25804101771bc 100644 --- a/airflow/providers/cohere/operators/embedding.py +++ b/airflow/providers/cohere/operators/embedding.py @@ -24,6 +24,8 @@ from airflow.providers.cohere.hooks.cohere import CohereHook if TYPE_CHECKING: + from cohere.core.request_options import RequestOptions + from airflow.utils.context import Context @@ -39,6 +41,17 @@ class CohereEmbeddingOperator(BaseOperator): information for Cohere. Defaults to "cohere_default". :param timeout: Timeout in seconds for Cohere API. :param max_retries: Number of times to retry before failing. + :param request_options: Request-specific configuration. + Fields: + - timeout_in_seconds: int. The number of seconds to await an API call before timing out. + + - max_retries: int. The max number of retries to attempt if the API call fails. + + - additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict + + - additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict + + - additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict """ template_fields: Sequence[str] = ("input_text",) @@ -49,6 +62,7 @@ def __init__( conn_id: str = CohereHook.default_conn_name, timeout: int | None = None, max_retries: int | None = None, + request_options: RequestOptions | None = None, **kwargs: Any, ): super().__init__(**kwargs) @@ -58,11 +72,17 @@ def __init__( self.input_text = input_text self.timeout = timeout self.max_retries = max_retries + self.request_options = request_options @cached_property def hook(self) -> CohereHook: """Return an instance of the CohereHook.""" - return CohereHook(conn_id=self.conn_id, timeout=self.timeout, max_retries=self.max_retries) + return CohereHook( + conn_id=self.conn_id, + timeout=self.timeout, + max_retries=self.max_retries, + request_options=self.request_options, + ) def execute(self, context: Context) -> list[list[float]]: """Embed texts using Cohere embed services.""" diff --git a/airflow/providers/cohere/provider.yaml b/airflow/providers/cohere/provider.yaml index 43d3f35372183..89c48d73a71ab 100644 --- a/airflow/providers/cohere/provider.yaml +++ b/airflow/providers/cohere/provider.yaml @@ -42,7 +42,7 @@ integrations: dependencies: - apache-airflow>=2.6.0 - - cohere>=4.37,<5 + - cohere>=5.0.0 hooks: - integration-name: Cohere diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 284ff97cc7616..93399ae5b0ff5 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -330,7 +330,7 @@ "cohere": { "deps": [ "apache-airflow>=2.6.0", - "cohere>=4.37,<5" + "cohere>=5.0.0" ], "devel-deps": [], "cross-providers-deps": [], diff --git a/tests/providers/cohere/hooks/test_cohere.py b/tests/providers/cohere/hooks/test_cohere.py index 8f566ec0c6405..893503959c0a8 100644 --- a/tests/providers/cohere/hooks/test_cohere.py +++ b/tests/providers/cohere/hooks/test_cohere.py @@ -31,16 +31,14 @@ class TestCohereHook: def test__get_api_key(self): api_key = "test" - api_url = "http://some_host.com" + base_url = "http://some_host.com" timeout = 150 max_retries = 5 with patch.object( CohereHook, "get_connection", - return_value=Connection(conn_type="cohere", password=api_key, host=api_url), + return_value=Connection(conn_type="cohere", password=api_key, host=base_url), ), patch("cohere.Client") as client: hook = CohereHook(timeout=timeout, max_retries=max_retries) _ = hook.get_conn - client.assert_called_once_with( - api_key=api_key, timeout=timeout, max_retries=max_retries, api_url=api_url - ) + client.assert_called_once_with(api_key=api_key, timeout=timeout, base_url=base_url) diff --git a/tests/providers/cohere/operators/test_embedding.py b/tests/providers/cohere/operators/test_embedding.py index 32dd83aa2614a..c150bfc0bfc7d 100644 --- a/tests/providers/cohere/operators/test_embedding.py +++ b/tests/providers/cohere/operators/test_embedding.py @@ -29,18 +29,18 @@ def test_cohere_embedding_operator(cohere_client, get_connection): Test Cohere client is getting called with the correct key and that the execute methods returns expected response. """ - embedded_obj = [1, 2, 3] + embedded_obj = [[1.0, 2.0, 3.0]] class resp: embeddings = embedded_obj api_key = "test" - api_url = "http://some_host.com" + base_url = "http://some_host.com" timeout = 150 max_retries = 5 texts = ["On Kernel-Target Alignment. We describe a family of global optimization procedures"] - get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=api_url) + get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=base_url) client_obj = MagicMock() cohere_client.return_value = client_obj client_obj.embed.return_value = resp @@ -50,7 +50,5 @@ class resp: ) val = op.execute(context={}) - cohere_client.assert_called_once_with( - api_key=api_key, api_url=api_url, timeout=timeout, max_retries=max_retries - ) + cohere_client.assert_called_once_with(api_key=api_key, base_url=base_url, timeout=timeout) assert val == embedded_obj From c63ebe69238945a230ebe414800eb7eb207cfbf3 Mon Sep 17 00:00:00 2001 From: Albert Okiri Date: Wed, 17 Apr 2024 11:32:34 +0300 Subject: [PATCH 2/4] confine dependency --- airflow/providers/cohere/provider.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/providers/cohere/provider.yaml b/airflow/providers/cohere/provider.yaml index 578b651ae0923..9850a7f5a9af7 100644 --- a/airflow/providers/cohere/provider.yaml +++ b/airflow/providers/cohere/provider.yaml @@ -44,6 +44,7 @@ integrations: dependencies: - apache-airflow>=2.6.0 - cohere>=5.0.0 + - opentelemetry-proto==1.24.0 hooks: - integration-name: Cohere From f8a2f512b5e828cbeab7bedfb00d730b33249935 Mon Sep 17 00:00:00 2001 From: Albert Okiri Date: Wed, 17 Apr 2024 11:47:58 +0300 Subject: [PATCH 3/4] update provider dependencies --- generated/provider_dependencies.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 349c96eeecebf..098f261c042f3 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -330,7 +330,8 @@ "cohere": { "deps": [ "apache-airflow>=2.6.0", - "cohere>=5.0.0" + "cohere>=5.0.0", + "opentelemetry-proto==1.24.0" ], "devel-deps": [], "cross-providers-deps": [], From f801132dd3d4aa3ef5e2faaa351a1a6c3c876358 Mon Sep 17 00:00:00 2001 From: Albert Okiri Date: Fri, 19 Apr 2024 12:07:04 +0300 Subject: [PATCH 4/4] fix mypy errors --- airflow/providers/cohere/hooks/cohere.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/airflow/providers/cohere/hooks/cohere.py b/airflow/providers/cohere/hooks/cohere.py index 052bdc352e05b..f80c51ce02c7f 100644 --- a/airflow/providers/cohere/hooks/cohere.py +++ b/airflow/providers/cohere/hooks/cohere.py @@ -19,16 +19,13 @@ import warnings from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import Any import cohere from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook -if TYPE_CHECKING: - from cohere.core.request_options import RequestOptions - class CohereHook(BaseHook): """ @@ -62,7 +59,7 @@ def __init__( conn_id: str = default_conn_name, timeout: int | None = None, max_retries: int | None = None, - request_options: RequestOptions | None = None, + request_options: dict | None = None, ) -> None: super().__init__() self.conn_id = conn_id @@ -75,7 +72,11 @@ def __init__( AirflowProviderDeprecationWarning, stacklevel=2, ) - self.request_options = {"max_retries": self.max_retries} if self.request_options is None else self.request_options.update({"max_retries": self.max_retries}) + self.request_options = ( + {"max_retries": self.max_retries} + if self.request_options is None + else self.request_options.update({"max_retries": self.max_retries}) + ) @cached_property def get_conn(self) -> cohere.Client: # type: ignore[override] @@ -84,7 +85,7 @@ def get_conn(self) -> cohere.Client: # type: ignore[override] def create_embeddings( self, texts: list[str], model: str = "embed-multilingual-v2.0" - ) -> list[list[float]]: + ) -> list[list[float]] | cohere.EmbedByTypeResponseEmbeddings: response = self.get_conn.embed(texts=texts, model=model, request_options=self.request_options) embeddings = response.embeddings return embeddings