diff --git a/libs/partners/together/langchain_together/chat_models.py b/libs/partners/together/langchain_together/chat_models.py
index 41ba4604e2caa..7dff1ca948ede 100644
--- a/libs/partners/together/langchain_together/chat_models.py
+++ b/libs/partners/together/langchain_together/chat_models.py
@@ -59,7 +59,7 @@ def _llm_type(self) -> str:
     together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
     """Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
     together_api_base: Optional[str] = Field(
-        default="https://api.together.ai/v1/chat/completions", alias="base_url"
+        default="https://api.together.ai/v1/", alias="base_url"
     )
 
     @root_validator()
diff --git a/libs/partners/together/langchain_together/embeddings.py b/libs/partners/together/langchain_together/embeddings.py
index 095892d660fda..b3dfb337b00a1 100644
--- a/libs/partners/together/langchain_together/embeddings.py
+++ b/libs/partners/together/langchain_together/embeddings.py
@@ -51,7 +51,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
-    """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
+    """Embeddings model name to use. For example, 'togethercomputer/m2-bert-80M-8k-retrieval'.
     """
 
     dimensions: Optional[int] = None
@@ -62,7 +62,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
     together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
     """API Key for Solar API."""
     together_api_base: str = Field(
-        default="https://api.together.ai/v1/embeddings", alias="base_url"
+        default="https://api.together.ai/v1/", alias="base_url"
    )
     """Endpoint URL to use."""
     embedding_ctx_length: int = 4096
@@ -166,12 +166,18 @@ def validate_environment(cls, values: Dict) -> Dict:
             "default_query": values["default_query"],
         }
         if not values.get("client"):
-            sync_specific = {"http_client": values["http_client"]}
+            sync_specific = (
+                {"http_client": values["http_client"]} if values["http_client"] else {}
+            )
             values["client"] = openai.OpenAI(
                 **client_params, **sync_specific
             ).embeddings
         if not values.get("async_client"):
-            async_specific = {"http_client": values["http_async_client"]}
+            async_specific = (
+                {"http_client": values["http_async_client"]}
+                if values["http_async_client"]
+                else {}
+            )
             values["async_client"] = openai.AsyncOpenAI(
                 **client_params, **async_specific
             ).embeddings
@@ -179,8 +185,6 @@ def validate_environment(cls, values: Dict) -> Dict:
 
     @property
     def _invocation_params(self) -> Dict[str, Any]:
-        self.model = self.model.replace("-query", "").replace("-passage", "")
-
         params: Dict = {"model": self.model, **self.model_kwargs}
         if self.dimensions is not None:
             params["dimensions"] = self.dimensions
@@ -197,7 +201,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """
         embeddings = []
         params = self._invocation_params
-        params["model"] = params["model"] + "-passage"
+        params["model"] = params["model"]
 
         for text in texts:
             response = self.client.create(input=text, **params)
@@ -217,7 +221,7 @@ def embed_query(self, text: str) -> List[float]:
             Embedding for the text.
""" params = self._invocation_params - params["model"] = params["model"] + "-query" + params["model"] = params["model"] response = self.client.create(input=text, **params) @@ -236,7 +240,7 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """ embeddings = [] params = self._invocation_params - params["model"] = params["model"] + "-passage" + params["model"] = params["model"] for text in texts: response = await self.async_client.create(input=text, **params) @@ -256,7 +260,7 @@ async def aembed_query(self, text: str) -> List[float]: Embedding for the text. """ params = self._invocation_params - params["model"] = params["model"] + "-query" + params["model"] = params["model"] response = await self.async_client.create(input=text, **params) diff --git a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py index 74f6cef8e151d..5c51ffe8feff4 100644 --- a/libs/partners/together/tests/integration_tests/test_chat_models_standard.py +++ b/libs/partners/together/tests/integration_tests/test_chat_models_standard.py @@ -17,5 +17,5 @@ def chat_model_class(self) -> Type[BaseChatModel]: @pytest.fixture def chat_model_params(self) -> dict: return { - "model": "meta-llama/Llama-3-8b-chat-hf", + "model": "mistralai/Mistral-7B-Instruct-v0.1", }