Skip to content

Commit

Permalink
community[minor]: Implement ZhipuAIEmbeddings interface (#22821)
Browse files Browse the repository at this point in the history
- **Description:** Implement the ZhipuAIEmbeddings interface, including:
     - The `embed_query` method
     - The `embed_documents` method

Refer to [ZhipuAI Embedding-2](https://open.bigmodel.cn/dev/api#text_embedding).

---------

Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
2 people authored and hinthornw committed Jun 20, 2024
1 parent b8d3050 commit 25c4336
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 0 deletions.
5 changes: 5 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@
from langchain_community.embeddings.yandex import (
YandexGPTEmbeddings,
)
from langchain_community.embeddings.zhipuai import (
ZhipuAIEmbeddings,
)

__all__ = [
"AlephAlphaAsymmetricSemanticEmbedding",
Expand Down Expand Up @@ -307,6 +310,7 @@
"VoyageEmbeddings",
"XinferenceEmbeddings",
"YandexGPTEmbeddings",
"ZhipuAIEmbeddings",
]

_module_lookup = {
Expand Down Expand Up @@ -387,6 +391,7 @@
"TitanTakeoffEmbed": "langchain_community.embeddings.titan_takeoff",
"PremAIEmbeddings": "langchain_community.embeddings.premai",
"YandexGPTEmbeddings": "langchain_community.embeddings.yandex",
"ZhipuAIEmbeddings": "langchain_community.embeddings.zhipuai",
}


Expand Down
76 changes: 76 additions & 0 deletions libs/community/langchain_community/embeddings/zhipuai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Any, Dict, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env


class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """ZhipuAI embedding models.

    To use, you should have the ``zhipuai`` python package installed, and the
    environment variable ``ZHIPUAI_API_KEY`` set with your API key, or pass the
    key as the named parameter ``api_key`` to the constructor.

    More instructions about ZhipuAI embeddings are available at
    https://open.bigmodel.cn/dev/api#vector

    Example:
        .. code-block:: python

            from langchain_community.embeddings import ZhipuAIEmbeddings

            embeddings = ZhipuAIEmbeddings(api_key="your-api-key")
            text = "This is a test query."
            query_result = embeddings.embed_query(text)
            # texts = ["This is a test query1.", "This is a test query2."]
            # doc_results = embeddings.embed_documents(texts)
    """

    # NOTE: this must not be underscore-prefixed. pydantic v1 does not treat
    # names starting with "_" as model fields, so a value stored under such a
    # key by the root validator would be silently discarded and attribute
    # access would return the class-level ``FieldInfo`` instead of the client.
    client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = Field(default="embedding-2")
    """Model name"""
    api_key: str
    """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and build the ZhipuAI client.

        Args:
            values: Raw constructor keyword arguments (pre-validation).

        Returns:
            ``values`` with ``api_key`` resolved and ``client`` populated.

        Raises:
            ImportError: If the ``zhipuai`` package is not installed.
        """
        values["api_key"] = get_from_dict_or_env(values, "api_key", "ZHIPUAI_API_KEY")
        # Keep the try body to the import only, so that errors raised while
        # constructing the client are not mistaken for a missing package.
        try:
            from zhipuai import ZhipuAI
        except ImportError as exc:
            raise ImportError(
                "Could not import zhipuai python package. "
                "Please install it with `pip install zhipuai`."
            ) from exc
        values["client"] = ZhipuAI(api_key=values["api_key"])
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: A text to embed.

        Returns:
            The embedding of ``text`` as a list of floats.
        """
        return self.embed_documents([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of text documents.

        Args:
            texts: A list of text documents to embed.

        Returns:
            A list of embeddings, one per returned data item of the API
            response. Each embedding is a list of float values.
        """
        if not texts:
            # Nothing to embed; avoid a pointless remote request.
            return []
        resp = self.client.embeddings.create(model=self.model, input=texts)
        return [item.embedding for item in resp.data]
19 changes: 19 additions & 0 deletions libs/community/tests/integration_tests/embeddings/test_zhipuai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Test ZhipuAI Text Embedding."""
from langchain_community.embeddings.zhipuai import ZhipuAIEmbeddings


def test_zhipuai_embedding_documents() -> None:
    """Embed two documents and check the result count and vector length."""
    embedder = ZhipuAIEmbeddings()  # type: ignore[call-arg]
    vectors = embedder.embed_documents(
        ["This is a test query1.", "This is a test query2."]
    )
    # One vector is expected per input document.
    assert len(vectors) == 2  # type: ignore[arg-type]
    # Each vector is expected to have 1024 components.
    assert len(vectors[0]) == 1024  # type: ignore[index]


def test_zhipuai_embedding_query() -> None:
    """Embed a single query string and check the vector length."""
    embedder = ZhipuAIEmbeddings()  # type: ignore[call-arg]
    vector = embedder.embed_query("This is a test query.")
    # The query vector is expected to have 1024 components.
    assert len(vector) == 1024  # type: ignore[arg-type]
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"OpenVINOEmbeddings",
"OpenVINOBgeEmbeddings",
"SolarEmbeddings",
"ZhipuAIEmbeddings",
]


Expand Down

0 comments on commit 25c4336

Please sign in to comment.