Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cohere 5.0.0 support #38465

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions airflow/providers/cohere/hooks/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
# under the License.
from __future__ import annotations

import warnings
from functools import cached_property
from typing import Any

import cohere

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook


Expand All @@ -34,6 +36,17 @@ class CohereHook(BaseHook):
:param conn_id: :ref:`Cohere connection id <howto/connection:cohere>`
:param timeout: Request timeout in seconds.
:param max_retries: Maximal number of retries for requests.
:param request_options: Request-specific configuration.
Fields:
- timeout_in_seconds: int. The number of seconds to await an API call before timing out.

- max_retries: int. The max number of retries to attempt if the API call fails.

- additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict

- additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict

- additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict
"""

conn_name_attr = "conn_id"
Expand All @@ -46,23 +59,34 @@ def __init__(
conn_id: str = default_conn_name,
timeout: int | None = None,
max_retries: int | None = None,
request_options: dict | None = None,
) -> None:
super().__init__()
self.conn_id = conn_id
self.timeout = timeout
self.max_retries = max_retries
self.request_options = request_options
if self.max_retries:
warnings.warn(
"Argument `max_retries` is deprecated. Please use `request_options` dict for function-specific request configuration instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.request_options = (
{"max_retries": self.max_retries}
if self.request_options is None
else self.request_options.update({"max_retries": self.max_retries})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we generally use an if/else expression like this to conditionally assign a value, which is the intent of the `if` branch here, but in the `else` branch you are updating the `self.request_options` dict in place, which seems out of place to me. It would be much better to have a simplified version of the code here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like `self.request_options` will be `None` every time in the else case, because `dict.update()` mutates the dict in place and returns `None`.

Screenshot 2024-05-21 at 2 47 44 PM

)

@cached_property
def get_conn(self) -> cohere.Client: # type: ignore[override]
conn = self.get_connection(self.conn_id)
return cohere.Client(
api_key=conn.password, timeout=self.timeout, max_retries=self.max_retries, api_url=conn.host
)
return cohere.Client(api_key=conn.password, timeout=self.timeout, base_url=conn.host)

def create_embeddings(
self, texts: list[str], model: str = "embed-multilingual-v2.0"
) -> list[list[float]]:
response = self.get_conn.embed(texts=texts, model=model)
) -> list[list[float]] | cohere.EmbedByTypeResponseEmbeddings:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the return type should be just `EmbedResponse` (ref). Is there any case where we would still return `list[list[float]]`?

response = self.get_conn.embed(texts=texts, model=model, request_options=self.request_options)
embeddings = response.embeddings
return embeddings

Expand All @@ -76,8 +100,10 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
}

def test_connection(self) -> tuple[bool, str]:
if self.max_retries:
self.request_options.update(max_retries=self.max_retries)
try:
self.get_conn.generate("Test", max_tokens=10)
self.get_conn.generate(prompt="Test", max_tokens=10, request_options=self.request_options)
return True, "Connection established"
except Exception as e:
return False, str(e)
22 changes: 21 additions & 1 deletion airflow/providers/cohere/operators/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from airflow.providers.cohere.hooks.cohere import CohereHook

if TYPE_CHECKING:
from cohere.core.request_options import RequestOptions

from airflow.utils.context import Context


Expand All @@ -39,6 +41,17 @@ class CohereEmbeddingOperator(BaseOperator):
information for Cohere. Defaults to "cohere_default".
:param timeout: Timeout in seconds for Cohere API.
:param max_retries: Number of times to retry before failing.
:param request_options: Request-specific configuration.
Fields:
- timeout_in_seconds: int. The number of seconds to await an API call before timing out.

- max_retries: int. The max number of retries to attempt if the API call fails.

- additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict

- additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict

- additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to point to the original docs here, in case they change in future releases.

"""

template_fields: Sequence[str] = ("input_text",)
Expand All @@ -49,6 +62,7 @@ def __init__(
conn_id: str = CohereHook.default_conn_name,
timeout: int | None = None,
max_retries: int | None = None,
request_options: RequestOptions | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand All @@ -58,11 +72,17 @@ def __init__(
self.input_text = input_text
self.timeout = timeout
self.max_retries = max_retries
self.request_options = request_options

@cached_property
def hook(self) -> CohereHook:
"""Return an instance of the CohereHook."""
return CohereHook(conn_id=self.conn_id, timeout=self.timeout, max_retries=self.max_retries)
return CohereHook(
conn_id=self.conn_id,
timeout=self.timeout,
max_retries=self.max_retries,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we still be passing `max_retries=self.max_retries` in the same PR that deprecates it? :)

request_options=self.request_options,
)

def execute(self, context: Context) -> list[list[float]]:
"""Embed texts using Cohere embed services."""
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/cohere/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ integrations:

dependencies:
- apache-airflow>=2.6.0
- cohere>=4.37,<5
- cohere>=5.0.0
- opentelemetry-proto==1.24.0

hooks:
- integration-name: Cohere
Expand Down
3 changes: 2 additions & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@
"cohere": {
"deps": [
"apache-airflow>=2.6.0",
"cohere>=4.37,<5"
"cohere>=5.0.0",
"opentelemetry-proto==1.24.0"
],
"devel-deps": [],
"cross-providers-deps": [],
Expand Down
8 changes: 3 additions & 5 deletions tests/providers/cohere/hooks/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,14 @@ class TestCohereHook:

def test__get_api_key(self):
api_key = "test"
api_url = "http://some_host.com"
base_url = "http://some_host.com"
timeout = 150
max_retries = 5
with patch.object(
CohereHook,
"get_connection",
return_value=Connection(conn_type="cohere", password=api_key, host=api_url),
return_value=Connection(conn_type="cohere", password=api_key, host=base_url),
), patch("cohere.Client") as client:
hook = CohereHook(timeout=timeout, max_retries=max_retries)
_ = hook.get_conn
client.assert_called_once_with(
api_key=api_key, timeout=timeout, max_retries=max_retries, api_url=api_url
)
client.assert_called_once_with(api_key=api_key, timeout=timeout, base_url=base_url)
10 changes: 4 additions & 6 deletions tests/providers/cohere/operators/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@ def test_cohere_embedding_operator(cohere_client, get_connection):
Test Cohere client is getting called with the correct key and that
the execute methods returns expected response.
"""
embedded_obj = [1, 2, 3]
embedded_obj = [[1.0, 2.0, 3.0]]

class resp:
embeddings = embedded_obj

api_key = "test"
api_url = "http://some_host.com"
base_url = "http://some_host.com"
timeout = 150
max_retries = 5
texts = ["On Kernel-Target Alignment. We describe a family of global optimization procedures"]

get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=api_url)
get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=base_url)
client_obj = MagicMock()
cohere_client.return_value = client_obj
client_obj.embed.return_value = resp
Expand All @@ -50,7 +50,5 @@ class resp:
)

val = op.execute(context={})
cohere_client.assert_called_once_with(
api_key=api_key, api_url=api_url, timeout=timeout, max_retries=max_retries
)
cohere_client.assert_called_once_with(api_key=api_key, base_url=base_url, timeout=timeout)
assert val == embedded_obj
Loading