-
Notifications
You must be signed in to change notification settings - Fork 16.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: Implement ZhipuAIEmbeddings interface (#22821)
- **Description:** Implement ZhipuAIEmbeddings interface, include: - The `embed_query` method - The `embed_documents` method refer to [ZhipuAI Embedding-2](https://open.bigmodel.cn/dev/api#text_embedding) --------- Co-authored-by: Eugene Yurtsev <[email protected]>
- Loading branch information
Showing
4 changed files
with
101 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Any, Dict, List | ||
|
||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator | ||
from langchain_core.utils import get_from_dict_or_env | ||
|
||
|
||
class ZhipuAIEmbeddings(BaseModel, Embeddings): | ||
"""ZhipuAI embedding models. | ||
To use, you should have the ``zhipuai`` python package installed, and the | ||
environment variable ``ZHIPU_API_KEY`` set with your API key or pass it | ||
as a named parameter to the constructor. | ||
More instructions about ZhipuAi Embeddings, you can get it | ||
from https://open.bigmodel.cn/dev/api#vector | ||
Example: | ||
.. code-block:: python | ||
from langchain_community.embeddings import ZhipuAIEmbeddings | ||
embeddings = ZhipuAIEmbeddings(api_key="your-api-key") | ||
text = "This is a test query." | ||
query_result = embeddings.embed_query(text) | ||
# texts = ["This is a test query1.", "This is a test query2."] | ||
# query_result = embeddings.embed_query(texts) | ||
""" | ||
|
||
_client: Any = Field(default=None, exclude=True) #: :meta private: | ||
model: str = Field(default="embedding-2") | ||
"""Model name""" | ||
api_key: str | ||
"""Automatically inferred from env var `ZHIPU_API_KEY` if not provided.""" | ||
|
||
@root_validator(pre=True) | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that auth token exists in environment.""" | ||
values["api_key"] = get_from_dict_or_env(values, "api_key", "ZHIPUAI_API_KEY") | ||
try: | ||
from zhipuai import ZhipuAI | ||
|
||
values["_client"] = ZhipuAI(api_key=values["api_key"]) | ||
except ImportError: | ||
raise ImportError( | ||
"Could not import zhipuai python package." | ||
"Please install it with `pip install zhipuai`." | ||
) | ||
return values | ||
|
||
def embed_query(self, text: str) -> List[float]: | ||
""" | ||
Embeds a text using the AutoVOT algorithm. | ||
Args: | ||
text: A text to embed. | ||
Returns: | ||
Input document's embedded list. | ||
""" | ||
resp = self.embed_documents([text]) | ||
return resp[0] | ||
|
||
def embed_documents(self, texts: List[str]) -> List[List[float]]: | ||
""" | ||
Embeds a list of text documents using the AutoVOT algorithm. | ||
Args: | ||
texts: A list of text documents to embed. | ||
Returns: | ||
A list of embeddings for each document in the input list. | ||
Each embedding is represented as a list of float values. | ||
""" | ||
resp = self._client.embeddings.create(model=self.model, input=texts) | ||
embeddings = [r.embedding for r in resp.data] | ||
return embeddings |
19 changes: 19 additions & 0 deletions
19
libs/community/tests/integration_tests/embeddings/test_zhipuai.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
"""Test ZhipuAI Text Embedding.""" | ||
from langchain_community.embeddings.zhipuai import ZhipuAIEmbeddings | ||
|
||
|
||
def test_zhipuai_embedding_documents() -> None: | ||
"""Test ZhipuAI Text Embedding for documents.""" | ||
documents = ["This is a test query1.", "This is a test query2."] | ||
embedding = ZhipuAIEmbeddings() # type: ignore[call-arg] | ||
res = embedding.embed_documents(documents) | ||
assert len(res) == 2 # type: ignore[arg-type] | ||
assert len(res[0]) == 1024 # type: ignore[index] | ||
|
||
|
||
def test_zhipuai_embedding_query() -> None: | ||
"""Test ZhipuAI Text Embedding for query.""" | ||
document = "This is a test query." | ||
embedding = ZhipuAIEmbeddings() # type: ignore[call-arg] | ||
res = embedding.embed_query(document) | ||
assert len(res) == 1024 # type: ignore[arg-type] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,6 +78,7 @@ | |
"OpenVINOEmbeddings", | ||
"OpenVINOBgeEmbeddings", | ||
"SolarEmbeddings", | ||
"ZhipuAIEmbeddings", | ||
] | ||
|
||
|
||
|