From 25c4336740a97b51756d40d8a1b63392910e89ca Mon Sep 17 00:00:00 2001 From: maang-h <55082429+maang-h@users.noreply.github.com> Date: Fri, 14 Jun 2024 10:45:11 +0800 Subject: [PATCH] community[minor]: Implement ZhipuAIEmbeddings interface (#22821) - **Description:** Implement ZhipuAIEmbeddings interface, include: - The `embed_query` method - The `embed_documents` method refer to [ZhipuAI Embedding-2](https://open.bigmodel.cn/dev/api#text_embedding) --------- Co-authored-by: Eugene Yurtsev --- .../embeddings/__init__.py | 5 ++ .../langchain_community/embeddings/zhipuai.py | 76 +++++++++++++++++++ .../embeddings/test_zhipuai.py | 19 +++++ .../unit_tests/embeddings/test_imports.py | 1 + 4 files changed, 101 insertions(+) create mode 100644 libs/community/langchain_community/embeddings/zhipuai.py create mode 100644 libs/community/tests/integration_tests/embeddings/test_zhipuai.py diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index bef8ec4b1ec99..5b49744a49475 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -228,6 +228,9 @@ from langchain_community.embeddings.yandex import ( YandexGPTEmbeddings, ) + from langchain_community.embeddings.zhipuai import ( + ZhipuAIEmbeddings, + ) __all__ = [ "AlephAlphaAsymmetricSemanticEmbedding", @@ -307,6 +310,7 @@ "VoyageEmbeddings", "XinferenceEmbeddings", "YandexGPTEmbeddings", + "ZhipuAIEmbeddings", ] _module_lookup = { @@ -387,6 +391,7 @@ "TitanTakeoffEmbed": "langchain_community.embeddings.titan_takeoff", "PremAIEmbeddings": "langchain_community.embeddings.premai", "YandexGPTEmbeddings": "langchain_community.embeddings.yandex", + "ZhipuAIEmbeddings": "langchain_community.embeddings.zhipuai", } diff --git a/libs/community/langchain_community/embeddings/zhipuai.py b/libs/community/langchain_community/embeddings/zhipuai.py new file mode 100644 index 0000000000000..31f973a735cec --- /dev/null +++ b/libs/community/langchain_community/embeddings/zhipuai.py @@ -0,0 +1,76 @@ +from typing import Any, Dict, List + +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.utils import get_from_dict_or_env + + +class ZhipuAIEmbeddings(BaseModel, Embeddings): + """ZhipuAI embedding models. + + To use, you should have the ``zhipuai`` python package installed, and the + environment variable ``ZHIPU_API_KEY`` set with your API key or pass it + as a named parameter to the constructor. + + More instructions about ZhipuAi Embeddings, you can get it + from https://open.bigmodel.cn/dev/api#vector + + Example: + .. code-block:: python + + from langchain_community.embeddings import ZhipuAIEmbeddings + embeddings = ZhipuAIEmbeddings(api_key="your-api-key") + text = "This is a test query." + query_result = embeddings.embed_query(text) + # texts = ["This is a test query1.", "This is a test query2."] + # query_result = embeddings.embed_query(texts) + """ + + _client: Any = Field(default=None, exclude=True) #: :meta private: + model: str = Field(default="embedding-2") + """Model name""" + api_key: str + """Automatically inferred from env var `ZHIPU_API_KEY` if not provided.""" + + @root_validator(pre=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that auth token exists in environment.""" + values["api_key"] = get_from_dict_or_env(values, "api_key", "ZHIPUAI_API_KEY") + try: + from zhipuai import ZhipuAI + + values["_client"] = ZhipuAI(api_key=values["api_key"]) + except ImportError: + raise ImportError( + "Could not import zhipuai python package." + "Please install it with `pip install zhipuai`." + ) + return values + + def embed_query(self, text: str) -> List[float]: + """ + Embeds a text using the AutoVOT algorithm. + + Args: + text: A text to embed. + + Returns: + Input document's embedded list. + """ + resp = self.embed_documents([text]) + return resp[0] + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Embeds a list of text documents using the AutoVOT algorithm. + + Args: + texts: A list of text documents to embed. + + Returns: + A list of embeddings for each document in the input list. + Each embedding is represented as a list of float values. + """ + resp = self._client.embeddings.create(model=self.model, input=texts) + embeddings = [r.embedding for r in resp.data] + return embeddings diff --git a/libs/community/tests/integration_tests/embeddings/test_zhipuai.py b/libs/community/tests/integration_tests/embeddings/test_zhipuai.py new file mode 100644 index 0000000000000..57ce6c19c9cd4 --- /dev/null +++ b/libs/community/tests/integration_tests/embeddings/test_zhipuai.py @@ -0,0 +1,19 @@ +"""Test ZhipuAI Text Embedding.""" +from langchain_community.embeddings.zhipuai import ZhipuAIEmbeddings + + +def test_zhipuai_embedding_documents() -> None: + """Test ZhipuAI Text Embedding for documents.""" + documents = ["This is a test query1.", "This is a test query2."] + embedding = ZhipuAIEmbeddings() # type: ignore[call-arg] + res = embedding.embed_documents(documents) + assert len(res) == 2 # type: ignore[arg-type] + assert len(res[0]) == 1024 # type: ignore[index] + + +def test_zhipuai_embedding_query() -> None: + """Test ZhipuAI Text Embedding for query.""" + document = "This is a test query." + embedding = ZhipuAIEmbeddings() # type: ignore[call-arg] + res = embedding.embed_query(document) + assert len(res) == 1024 # type: ignore[arg-type] diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index 85b8a37e21d08..7f991488f3b6f 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -78,6 +78,7 @@ "OpenVINOEmbeddings", "OpenVINOBgeEmbeddings", "SolarEmbeddings", + "ZhipuAIEmbeddings", ]