-
Notifications
You must be signed in to change notification settings - Fork 14.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
cohere 5.0.0 support #38465
cohere 5.0.0 support #38465
Changes from all commits
e5b0b4e
c70c8e4
c63ebe6
f8a2f51
c3c63b2
f801132
e7cbb96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,13 @@ | |
# under the License. | ||
from __future__ import annotations | ||
|
||
import warnings | ||
from functools import cached_property | ||
from typing import Any | ||
|
||
import cohere | ||
|
||
from airflow.exceptions import AirflowProviderDeprecationWarning | ||
from airflow.hooks.base import BaseHook | ||
|
||
|
||
|
@@ -34,6 +36,17 @@ class CohereHook(BaseHook): | |
:param conn_id: :ref:`Cohere connection id <howto/connection:cohere>` | ||
:param timeout: Request timeout in seconds. | ||
:param max_retries: Maximal number of retries for requests. | ||
:param request_options: Request-specific configuration. | ||
Fields: | ||
- timeout_in_seconds: int. The number of seconds to await an API call before timing out. | ||
|
||
- max_retries: int. The max number of retries to attempt if the API call fails. | ||
|
||
- additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict | ||
|
||
- additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict | ||
|
||
- additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict | ||
""" | ||
|
||
conn_name_attr = "conn_id" | ||
|
@@ -46,23 +59,34 @@ def __init__( | |
conn_id: str = default_conn_name, | ||
timeout: int | None = None, | ||
max_retries: int | None = None, | ||
request_options: dict | None = None, | ||
) -> None: | ||
super().__init__() | ||
self.conn_id = conn_id | ||
self.timeout = timeout | ||
self.max_retries = max_retries | ||
self.request_options = request_options | ||
if self.max_retries: | ||
warnings.warn( | ||
"Argument `max_retries` is deprecated. Please use `request_options` dict for function-specific request configuration instead.", | ||
AirflowProviderDeprecationWarning, | ||
stacklevel=2, | ||
) | ||
self.request_options = ( | ||
{"max_retries": self.max_retries} | ||
if self.request_options is None | ||
else self.request_options.update({"max_retries": self.max_retries}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
) | ||
|
||
@cached_property | ||
def get_conn(self) -> cohere.Client: # type: ignore[override] | ||
conn = self.get_connection(self.conn_id) | ||
return cohere.Client( | ||
api_key=conn.password, timeout=self.timeout, max_retries=self.max_retries, api_url=conn.host | ||
) | ||
return cohere.Client(api_key=conn.password, timeout=self.timeout, base_url=conn.host) | ||
|
||
def create_embeddings( | ||
self, texts: list[str], model: str = "embed-multilingual-v2.0" | ||
) -> list[list[float]]: | ||
response = self.get_conn.embed(texts=texts, model=model) | ||
) -> list[list[float]] | cohere.EmbedByTypeResponseEmbeddings: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think return type should be just |
||
response = self.get_conn.embed(texts=texts, model=model, request_options=self.request_options) | ||
embeddings = response.embeddings | ||
return embeddings | ||
|
||
|
@@ -76,8 +100,10 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: | |
} | ||
|
||
def test_connection(self) -> tuple[bool, str]: | ||
if self.max_retries: | ||
self.request_options.update(max_retries=self.max_retries) | ||
try: | ||
self.get_conn.generate("Test", max_tokens=10) | ||
self.get_conn.generate(prompt="Test", max_tokens=10, request_options=self.request_options) | ||
return True, "Connection established" | ||
except Exception as e: | ||
return False, str(e) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,8 @@ | |
from airflow.providers.cohere.hooks.cohere import CohereHook | ||
|
||
if TYPE_CHECKING: | ||
from cohere.core.request_options import RequestOptions | ||
|
||
from airflow.utils.context import Context | ||
|
||
|
||
|
@@ -39,6 +41,17 @@ class CohereEmbeddingOperator(BaseOperator): | |
information for Cohere. Defaults to "cohere_default". | ||
:param timeout: Timeout in seconds for Cohere API. | ||
:param max_retries: Number of times to retry before failing. | ||
:param request_options: Request-specific configuration. | ||
Fields: | ||
- timeout_in_seconds: int. The number of seconds to await an API call before timing out. | ||
|
||
- max_retries: int. The max number of retries to attempt if the API call fails. | ||
|
||
- additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict | ||
|
||
- additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict | ||
|
||
- additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to point to original docs here, in case they change in future releases. |
||
""" | ||
|
||
template_fields: Sequence[str] = ("input_text",) | ||
|
@@ -49,6 +62,7 @@ def __init__( | |
conn_id: str = CohereHook.default_conn_name, | ||
timeout: int | None = None, | ||
max_retries: int | None = None, | ||
request_options: RequestOptions | None = None, | ||
**kwargs: Any, | ||
): | ||
super().__init__(**kwargs) | ||
|
@@ -58,11 +72,17 @@ def __init__( | |
self.input_text = input_text | ||
self.timeout = timeout | ||
self.max_retries = max_retries | ||
self.request_options = request_options | ||
|
||
@cached_property | ||
def hook(self) -> CohereHook: | ||
"""Return an instance of the CohereHook.""" | ||
return CohereHook(conn_id=self.conn_id, timeout=self.timeout, max_retries=self.max_retries) | ||
return CohereHook( | ||
conn_id=self.conn_id, | ||
timeout=self.timeout, | ||
max_retries=self.max_retries, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use |
||
request_options=self.request_options, | ||
) | ||
|
||
def execute(self, context: Context) -> list[list[float]]: | ||
"""Embed texts using Cohere embed services.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we generally use if/else in this manner to conditionally assign values, which is the intent of
if
clause here but inelse
you are updating theself.request_options
dict. Which to me seems out of place. It would be much better to have a simplified version of the code here.