diff --git a/cookbook/rewrite.ipynb b/cookbook/rewrite.ipynb index 270d7d964edd5..12f5a9e734a90 100644 --- a/cookbook/rewrite.ipynb +++ b/cookbook/rewrite.ipynb @@ -245,7 +245,7 @@ "\n", "\n", "def _parse(text):\n", - " return text.strip(\"**\")" + " return text.strip('\"').strip(\"**\")" ] }, { diff --git a/docs/docs/integrations/retrievers/jina-reranker.ipynb b/docs/docs/integrations/retrievers/jina-reranker.ipynb index 86299efa60265..7a235d6934a07 100644 --- a/docs/docs/integrations/retrievers/jina-reranker.ipynb +++ b/docs/docs/integrations/retrievers/jina-reranker.ipynb @@ -41,6 +41,7 @@ "source": [ "# Helper function for printing docs\n", "\n", + "\n", "def pretty_print_docs(docs):\n", " print(\n", " f\"\\n{'-' * 100}\\n\".join(\n", @@ -125,9 +126,7 @@ "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n", "texts = text_splitter.split_documents(documents)\n", "\n", - "embedding = JinaEmbeddings(\n", - " model_name=\"jina-embeddings-v2-base-en\"\n", - ")\n", + "embedding = JinaEmbeddings(model_name=\"jina-embeddings-v2-base-en\")\n", "retriever = FAISS.from_documents(texts, embedding).as_retriever(search_kwargs={\"k\": 20})\n", "\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n", diff --git a/docs/docs/integrations/toolkits/pandas.ipynb b/docs/docs/integrations/toolkits/pandas.ipynb index 5d4cf3e6fc04f..5bf09f5b2c9d6 100644 --- a/docs/docs/integrations/toolkits/pandas.ipynb +++ b/docs/docs/integrations/toolkits/pandas.ipynb @@ -34,7 +34,9 @@ "import pandas as pd\n", "from langchain_openai import OpenAI\n", "\n", - "df = pd.read_csv(\"titanic.csv\")" + "df = pd.read_csv(\n", + " \"https://raw.githubusercontent.com/pandas-dev/pandas/main/doc/data/titanic.csv\"\n", + ")" ] }, { @@ -116,7 +118,7 @@ } ], "source": [ - "agent.run(\"how many rows are there?\")" + "agent.invoke(\"how many rows are there?\")" ] }, { @@ -154,7 +156,7 @@ } ], "source": [ - "agent.run(\"how many people have more than 3 siblings\")" + "agent.invoke(\"how many people have more than 3 siblings\")" ] }, { @@ -204,7 +206,7 @@ } ], "source": [ - "agent.run(\"whats the square root of the average age?\")" + "agent.invoke(\"whats the square root of the average age?\")" ] }, { @@ -264,7 +266,7 @@ ], "source": [ "agent = create_pandas_dataframe_agent(OpenAI(temperature=0), [df, df1], verbose=True)\n", - "agent.run(\"how many rows in the age column are different?\")" + "agent.invoke(\"how many rows in the age column are different?\")" ] }, { diff --git a/libs/cli/langchain_cli/utils/git.py b/libs/cli/langchain_cli/utils/git.py index 663e2773354ef..e7e4fe8641511 100644 --- a/libs/cli/langchain_cli/utils/git.py +++ b/libs/cli/langchain_cli/utils/git.py @@ -155,7 +155,7 @@ def _get_repo_path(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path: removed_protocol = gitstring.split("://")[-1] removed_basename = re.split(r"[/:]", removed_protocol, 1)[-1] removed_extras = removed_basename.split("#")[0] - foldername = re.sub(r"[^a-zA-Z0-9_]", "_", removed_extras) + foldername = re.sub(r"\W", "_", removed_extras) directory_name = f"{foldername}_{hashed}" return repo_dir / directory_name diff --git a/libs/community/langchain_community/chat_models/cohere.py b/libs/community/langchain_community/chat_models/cohere.py index e3f20ad9c435f..7682f1d4ca39e 100644 --- a/libs/community/langchain_community/chat_models/cohere.py +++ b/libs/community/langchain_community/chat_models/cohere.py @@ -244,4 +244,4 @@ async def _agenerate( def get_num_tokens(self, text: str) -> int: 
"""Calculate number of tokens.""" - return len(self.client.tokenize(text).tokens) + return len(self.client.tokenize(text=text).tokens) diff --git a/libs/community/langchain_community/document_loaders/recursive_url_loader.py b/libs/community/langchain_community/document_loaders/recursive_url_loader.py index 698af55c60698..6231c7af8d665 100644 --- a/libs/community/langchain_community/document_loaders/recursive_url_loader.py +++ b/libs/community/langchain_community/document_loaders/recursive_url_loader.py @@ -94,6 +94,8 @@ def __init__( headers: Optional[dict] = None, check_response_status: bool = False, continue_on_failure: bool = True, + *, + base_url: Optional[str] = None, ) -> None: """Initialize with URL to crawl and any subdirectories to exclude. @@ -120,6 +122,7 @@ def __init__( URLs with error responses (400-599). continue_on_failure: If True, continue if getting or parsing a link raises an exception. Otherwise, raise the exception. + base_url: The base url to check for outside links against. """ self.url = url @@ -146,6 +149,7 @@ def __init__( self.headers = headers self.check_response_status = check_response_status self.continue_on_failure = continue_on_failure + self.base_url = base_url if base_url is not None else url def _get_child_links_recursive( self, url: str, visited: Set[str], *, depth: int = 0 @@ -187,7 +191,7 @@ def _get_child_links_recursive( sub_links = extract_sub_links( response.text, url, - base_url=self.url, + base_url=self.base_url, pattern=self.link_regex, prevent_outside=self.prevent_outside, exclude_prefixes=self.exclude_dirs, @@ -273,7 +277,7 @@ async def _async_get_child_links_recursive( sub_links = extract_sub_links( text, url, - base_url=self.url, + base_url=self.base_url, pattern=self.link_regex, prevent_outside=self.prevent_outside, exclude_prefixes=self.exclude_dirs, diff --git a/libs/community/langchain_community/llms/llamacpp.py b/libs/community/langchain_community/llms/llamacpp.py index 85acfb999e9ca..b06e6d8cf763c 100644 --- a/libs/community/langchain_community/llms/llamacpp.py +++ b/libs/community/langchain_community/llms/llamacpp.py @@ -344,11 +344,11 @@ def _stream( text=part["choices"][0]["text"], generation_info={"logprobs": logprobs}, ) - yield chunk if run_manager: run_manager.on_llm_new_token( token=chunk.text, verbose=self.verbose, log_probs=logprobs ) + yield chunk def get_num_tokens(self, text: str) -> int: tokenized_text = self.client.tokenize(text.encode("utf-8")) diff --git a/libs/community/langchain_community/vectorstores/docarray/hnsw.py b/libs/community/langchain_community/vectorstores/docarray/hnsw.py index 4394847184969..8b33f7a0d58a6 100644 --- a/libs/community/langchain_community/vectorstores/docarray/hnsw.py +++ b/libs/community/langchain_community/vectorstores/docarray/hnsw.py @@ -14,7 +14,7 @@ class DocArrayHnswSearch(DocArrayIndex): """`HnswLib` storage using `DocArray` package. To use it, you should have the ``docarray`` package with version >=0.32.0 installed. - You can install it with `pip install "langchain[docarray]"`. + You can install it with `pip install "docarray[hnswlib]"`. 
""" @classmethod diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 1c74bdd9c4937..e4be7adf872e7 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -551,7 +551,10 @@ def pretty_print(self) -> None: MessageLikeRepresentation = Union[ MessageLike, - Tuple[Union[str, Type], Union[str, List[dict], List[object]]], + Tuple[ + Union[str, Type], + Union[str, List[dict], List[object]], + ], str, ] @@ -590,6 +593,45 @@ class ChatPromptTemplate(BaseChatPromptTemplate): # ] #) + Messages Placeholder: + + .. code-block:: python + + # In addition to Human/AI/Tool/Function messages, + # you can initialize the template with a MessagesPlaceholder + # either using the class directly or with the shorthand tuple syntax: + + template = ChatPromptTemplate.from_messages([ + ("system", "You are a helpful AI bot."), + # Means the template will receive an optional list of messages under + # the "conversation" key + ("placeholder", "{conversation}") + # Equivalently: + # MessagesPlaceholder(variable_name="conversation", optional=True) + ]) + + prompt_value = template.invoke( + { + "conversation": [ + ("human", "Hi!"), + ("ai", "How can I assist you today?"), + ("human", "Can you make me an ice cream sundae?"), + ("ai", "No.") + ] + } + ) + + # Output: + # ChatPromptValue( + # messages=[ + # SystemMessage(content='You are a helpful AI bot.'), + # HumanMessage(content='Hi!'), + # AIMessage(content='How can I assist you today?'), + # HumanMessage(content='Can you make me an ice cream sundae?'), + # AIMessage(content='No.'), + # ] + #) + Single-variable template: If your prompt has only a single input variable (i.e., 1 instance of "{variable_nams}"), @@ -949,6 +991,36 @@ def _create_template_from_message_type( message = AIMessagePromptTemplate.from_template(cast(str, template)) elif message_type == "system": message = SystemMessagePromptTemplate.from_template(cast(str, template)) + elif message_type == "placeholder": + if isinstance(template, str): + if template[0] != "{" or template[-1] != "}": + raise ValueError( + f"Invalid placeholder template: {template}." + " Expected a variable name surrounded by curly braces." + ) + var_name = template[1:-1] + message = MessagesPlaceholder(variable_name=var_name, optional=True) + elif len(template) == 2 and isinstance(template[1], bool): + var_name_wrapped, is_optional = template + if not isinstance(var_name_wrapped, str): + raise ValueError( + "Expected variable name to be a string." f" Got: {var_name_wrapped}" + ) + if var_name_wrapped[0] != "{" or var_name_wrapped[-1] != "}": + raise ValueError( + f"Invalid placeholder template: {var_name_wrapped}." + " Expected a variable name surrounded by curly braces." + ) + var_name = var_name_wrapped[1:-1] + + message = MessagesPlaceholder(variable_name=var_name, optional=is_optional) + else: + raise ValueError( + "Unexpected arguments for placeholder message type." + " Expected either a single string variable name" + " or a list of [variable_name: str, is_optional: bool]." + f" Got: {template}" + ) else: raise ValueError( f"Unexpected message type: {message_type}. 
Use one of 'human'," diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 5977615654ce9..a406b87097eb5 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -3,6 +3,7 @@ from __future__ import annotations import inspect +import uuid from typing import ( TYPE_CHECKING, Any, @@ -20,6 +21,12 @@ from typing_extensions import TypedDict from langchain_core._api import deprecated +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + ToolMessage, +) from langchain_core.pydantic_v1 import BaseModel from langchain_core.utils.json_schema import dereference_refs @@ -332,3 +339,96 @@ def convert_to_openai_tool( return tool function = convert_to_openai_function(tool) return {"type": "function", "function": function} + + +def tool_example_to_messages( + input: str, tool_calls: List[BaseModel], tool_outputs: Optional[List[str]] = None +) -> List[BaseMessage]: + """Convert an example into a list of messages that can be fed into an LLM. + + This code is an adapter that converts a single example to a list of messages + that can be fed into a chat model. + + The list of messages per example corresponds to: + + 1) HumanMessage: contains the content from which content should be extracted. + 2) AIMessage: contains the extracted information from the model + 3) ToolMessage: contains confirmation to the model that the model requested a tool + correctly. + + The ToolMessage is required because some chat models are hyper-optimized for agents + rather than for an extraction use case. + + Arguments: + input: string, the user input + tool_calls: List[BaseModel], a list of tool calls represented as Pydantic + BaseModels + tool_outputs: Optional[List[str]], a list of tool call outputs. + Does not need to be provided. If not provided, a placeholder value + will be inserted. + + Returns: + A list of messages + + Examples: + + .. code-block:: python + + from typing import List, Optional + from langchain_core.pydantic_v1 import BaseModel, Field + from langchain_openai import ChatOpenAI + + class Person(BaseModel): + '''Information about a person.''' + name: Optional[str] = Field(..., description="The name of the person") + hair_color: Optional[str] = Field( + ..., description="The color of the peron's eyes if known" + ) + height_in_meters: Optional[str] = Field( + ..., description="Height in METERs" + ) + + examples = [ + ( + "The ocean is vast and blue. It's more than 20,000 feet deep.", + Person(name=None, height_in_meters=None, hair_color=None), + ), + ( + "Fiona traveled far from France to Spain.", + Person(name="Fiona", height_in_meters=None, hair_color=None), + ), + ] + + + messages = [] + + for txt, tool_call in examples: + messages.extend( + tool_example_to_messages(txt, [tool_call]) + ) + """ + messages: List[BaseMessage] = [HumanMessage(content=input)] + openai_tool_calls = [] + for tool_call in tool_calls: + openai_tool_calls.append( + { + "id": str(uuid.uuid4()), + "type": "function", + "function": { + # The name of the function right now corresponds to the name + # of the pydantic model. This is implicit in the API right now, + # and will be improved over time. 
+ "name": tool_call.__class__.__name__, + "arguments": tool_call.json(), + }, + } + ) + messages.append( + AIMessage(content="", additional_kwargs={"tool_calls": openai_tool_calls}) + ) + tool_outputs = tool_outputs or ["You have correctly called this tool."] * len( + openai_tool_calls + ) + for output, tool_call_dict in zip(tool_outputs, openai_tool_calls): + messages.append(ToolMessage(content=output, tool_call_id=tool_call_dict["id"])) # type: ignore + return messages diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index bcc1633cc61df..305e981dfe9ea 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -535,6 +535,25 @@ def test_chat_prompt_message_placeholder_partial() -> None: assert prompt.format_messages() == [SystemMessage(content="foo")] +def test_chat_prompt_message_placeholder_tuple() -> None: + prompt = ChatPromptTemplate.from_messages([("placeholder", "{convo}")]) + assert prompt.format_messages(convo=[("user", "foo")]) == [ + HumanMessage(content="foo") + ] + + assert prompt.format_messages() == [] + + # Is optional = True + optional_prompt = ChatPromptTemplate.from_messages( + [("placeholder", ["{convo}", False])] + ) + assert optional_prompt.format_messages(convo=[("user", "foo")]) == [ + HumanMessage(content="foo") + ] + with pytest.raises(KeyError): + assert optional_prompt.format_messages() == [] + + def test_messages_prompt_accepts_list() -> None: prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) value = prompt.invoke([("user", "Hi there")]) # type: ignore diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 629cf769c5587..00328bcf29b44 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -2,9 +2,13 @@ import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool, tool -from langchain_core.utils.function_calling import convert_to_openai_function +from langchain_core.utils.function_calling import ( + convert_to_openai_function, + tool_example_to_messages, +) @pytest.fixture() @@ -109,3 +113,74 @@ def func5( func = convert_to_openai_function(func5) req = func["parameters"]["required"] assert set(req) == {"b"} + + +class FakeCall(BaseModel): + data: str + + +def test_valid_example_conversion() -> None: + expected_messages = [ + HumanMessage(content="This is a valid example"), + AIMessage(content="", additional_kwargs={"tool_calls": []}), + ] + assert ( + tool_example_to_messages(input="This is a valid example", tool_calls=[]) + == expected_messages + ) + + +def test_multiple_tool_calls() -> None: + messages = tool_example_to_messages( + input="This is an example", + tool_calls=[ + FakeCall(data="ToolCall1"), + FakeCall(data="ToolCall2"), + FakeCall(data="ToolCall3"), + ], + ) + assert len(messages) == 5 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert isinstance(messages[2], ToolMessage) + assert isinstance(messages[3], ToolMessage) + assert isinstance(messages[4], ToolMessage) + assert messages[1].additional_kwargs["tool_calls"] == [ + { + "id": messages[2].tool_call_id, + "type": "function", + "function": {"name": "FakeCall", "arguments": '{"data": "ToolCall1"}'}, + }, + { + "id": 
messages[3].tool_call_id, + "type": "function", + "function": {"name": "FakeCall", "arguments": '{"data": "ToolCall2"}'}, + }, + { + "id": messages[4].tool_call_id, + "type": "function", + "function": {"name": "FakeCall", "arguments": '{"data": "ToolCall3"}'}, + }, + ] + + +def test_tool_outputs() -> None: + messages = tool_example_to_messages( + input="This is an example", + tool_calls=[ + FakeCall(data="ToolCall1"), + ], + tool_outputs=["Output1"], + ) + assert len(messages) == 3 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert isinstance(messages[2], ToolMessage) + assert messages[1].additional_kwargs["tool_calls"] == [ + { + "id": messages[2].tool_call_id, + "type": "function", + "function": {"name": "FakeCall", "arguments": '{"data": "ToolCall1"}'}, + }, + ] + assert messages[2].content == "Output1" diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index e7cb100ae5007..76195fec352e0 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -3,12 +3,23 @@ import json from json import JSONDecodeError from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import CallbackManager from langchain_core.load import dumpd -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool @@ -76,6 +87,32 @@ def _get_openai_async_client() -> openai.AsyncOpenAI: ) from e +def _is_assistants_builtin_tool( + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], +) -> bool: + """Determine if tool corresponds to OpenAI Assistants built-in.""" + assistants_builtin_tools = ("code_interpreter", "retrieval") + return ( + isinstance(tool, dict) + and ("type" in tool) + and (tool["type"] in assistants_builtin_tools) + ) + + +def _get_assistants_tool( + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], +) -> Dict[str, Any]: + """Convert a raw function/class to an OpenAI tool. + + Note that OpenAI assistants supports several built-in tools, + such as "code_interpreter" and "retrieval." + """ + if _is_assistants_builtin_tool(tool): + return tool # type: ignore + else: + return convert_to_openai_tool(tool) + + OutputType = Union[ List[OpenAIAssistantAction], OpenAIAssistantFinish, @@ -210,7 +247,7 @@ def create_assistant( assistant = client.beta.assistants.create( name=name, instructions=instructions, - tools=[convert_to_openai_tool(tool) for tool in tools], # type: ignore + tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore model=model, file_ids=kwargs.get("file_ids"), ) @@ -328,7 +365,7 @@ async def acreate_assistant( AsyncOpenAIAssistantRunnable configured to run using the created assistant. 
""" async_client = async_client or _get_openai_async_client() - openai_tools = [convert_to_openai_tool(tool) for tool in tools] + openai_tools = [_get_assistants_tool(tool) for tool in tools] assistant = await async_client.beta.assistants.create( name=name, instructions=instructions, diff --git a/libs/langchain/langchain/memory/vectorstore.py b/libs/langchain/langchain/memory/vectorstore.py index 3f4430ee7cf9a..b288ef57d8424 100644 --- a/libs/langchain/langchain/memory/vectorstore.py +++ b/libs/langchain/langchain/memory/vectorstore.py @@ -39,13 +39,9 @@ def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str: return get_prompt_input_key(inputs, self.memory_variables) return self.input_key - def load_memory_variables( - self, inputs: Dict[str, Any] + def _documents_to_memory_variables( + self, docs: List[Document] ) -> Dict[str, Union[List[Document], str]]: - """Return history buffer.""" - input_key = self._get_prompt_input_key(inputs) - query = inputs[input_key] - docs = self.retriever.get_relevant_documents(query) result: Union[List[Document], str] if not self.return_docs: result = "\n".join([doc.page_content for doc in docs]) @@ -53,6 +49,24 @@ def load_memory_variables( result = docs return {self.memory_key: result} + def load_memory_variables( + self, inputs: Dict[str, Any] + ) -> Dict[str, Union[List[Document], str]]: + """Return history buffer.""" + input_key = self._get_prompt_input_key(inputs) + query = inputs[input_key] + docs = self.retriever.get_relevant_documents(query) + return self._documents_to_memory_variables(docs) + + async def aload_memory_variables( + self, inputs: Dict[str, Any] + ) -> Dict[str, Union[List[Document], str]]: + """Return history buffer.""" + input_key = self._get_prompt_input_key(inputs) + query = inputs[input_key] + docs = await self.retriever.aget_relevant_documents(query) + return self._documents_to_memory_variables(docs) + def _form_documents( self, inputs: Dict[str, Any], outputs: Dict[str, str] ) -> List[Document]: @@ -73,5 +87,15 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: documents = self._form_documents(inputs, outputs) self.retriever.add_documents(documents) + async def asave_context( + self, inputs: Dict[str, Any], outputs: Dict[str, str] + ) -> None: + """Save context from this conversation to buffer.""" + documents = self._form_documents(inputs, outputs) + await self.retriever.aadd_documents(documents) + def clear(self) -> None: """Nothing to clear.""" + + async def aclear(self) -> None: + """Nothing to clear.""" diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py index aaa4ba48d1df0..45fcea4ad2af5 100644 --- a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py +++ b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py @@ -1,8 +1,20 @@ +from functools import partial +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + import pytest from langchain.agents.openai_assistant import OpenAIAssistantRunnable +def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any: + client = AsyncMock() if use_async else MagicMock() + mock_assistant = MagicMock() + mock_assistant.id = "abc123" + client.beta.assistants.create.return_value = mock_assistant # type: ignore + return client + + @pytest.mark.requires("openai") def test_user_supplied_client() -> None: import openai @@ -19,3 +31,34 @@ def test_user_supplied_client() -> None: ) assert 
assistant.client == client + + +@pytest.mark.requires("openai") +@patch( + "langchain.agents.openai_assistant.base._get_openai_client", + new=partial(_create_mock_client, use_async=False), +) +def test_create_assistant() -> None: + assistant = OpenAIAssistantRunnable.create_assistant( + name="name", + instructions="instructions", + tools=[{"type": "code_interpreter"}], + model="", + ) + assert isinstance(assistant, OpenAIAssistantRunnable) + + +@pytest.mark.requires("openai") +@patch( + "langchain.agents.openai_assistant.base._get_openai_async_client", + new=partial(_create_mock_client, use_async=True), +) +async def test_acreate_assistant() -> None: + assistant = await OpenAIAssistantRunnable.acreate_assistant( + name="name", + instructions="instructions", + tools=[{"type": "code_interpreter"}], + model="", + client=_create_mock_client(), + ) + assert isinstance(assistant, OpenAIAssistantRunnable) diff --git a/libs/partners/mistralai/docs/embeddings.ipynb b/libs/partners/mistralai/docs/embeddings.ipynb deleted file mode 100644 index 33ed1137fdc9f..0000000000000 --- a/libs/partners/mistralai/docs/embeddings.ipynb +++ /dev/null @@ -1,103 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "b14a24db", - "metadata": {}, - "source": [ - "# MistralAIEmbeddings\n", - "\n", - "This notebook explains how to use MistralAIEmbeddings, which is included in the langchain_mistralai package, to embed texts in langchain." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "0ab948fc", - "metadata": {}, - "outputs": [], - "source": [ - "# pip install -U langchain-mistralai" - ] - }, - { - "cell_type": "markdown", - "id": "67c637ca", - "metadata": {}, - "source": [ - "## import the library" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "5709b030", - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_mistralai import MistralAIEmbeddings" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "1756b1ba", - "metadata": {}, - "outputs": [], - "source": [ - "embedding = MistralAIEmbeddings(mistral_api_key='your-api-key')" - ] - }, - { - "cell_type": "markdown", - "id": "4a2a098d", - "metadata": {}, - "source": [ - "# Using the Embedding Model\n", - "With `MistralAIEmbeddings`, you can directly use the default model 'mistral-embed', or set a different one if available." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "584b9af5", - "metadata": {}, - "outputs": [], - "source": [ - "embedding.model = 'mistral-embed' # or your preferred model if available" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "be18b873", - "metadata": {}, - "outputs": [], - "source": [ - "res_query = embedding.embed_query(\"The test information\")\n", - "res_document = embedding.embed_documents([\"test1\", \"another test\"])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 6d9fbbb18d2eb..2566f999cd388 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -1,10 +1,10 @@ from __future__ import annotations -import importlib.util import logging from operator import itemgetter from typing import ( Any, + AsyncContextManager, AsyncIterator, Callable, Dict, @@ -18,6 +18,8 @@ cast, ) +import httpx +from httpx_sse import EventSource, aconnect_sse, connect_sse from langchain_core._api import beta from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -54,19 +56,6 @@ from langchain_core.tools import BaseTool from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils.function_calling import convert_to_openai_tool -from mistralai.async_client import MistralAsyncClient -from mistralai.client import MistralClient -from mistralai.constants import ENDPOINT as DEFAULT_MISTRAL_ENDPOINT -from mistralai.exceptions import ( - MistralAPIException, - MistralConnectionException, - MistralException, -) -from mistralai.models.chat_completion import ( - ChatCompletionResponse as MistralChatCompletionResponse, -) -from mistralai.models.chat_completion import ChatMessage as MistralChatMessage -from mistralai.models.chat_completion import DeltaMessage as MistralDeltaMessage logger = logging.getLogger(__name__) @@ -79,36 +68,34 @@ def _create_retry_decorator( ) -> Callable[[Any], Any]: """Returns a tenacity retry decorator, preconfigured to handle exceptions""" - errors = [ - MistralException, - MistralAPIException, - MistralConnectionException, - ] + errors = [httpx.RequestError, httpx.StreamError] return create_base_retry_decorator( error_types=errors, max_retries=llm.max_retries, run_manager=run_manager ) def _convert_mistral_chat_message_to_message( - _message: MistralChatMessage, + _message: Dict, ) -> BaseMessage: - role = _message.role - content = cast(Union[str, List], _message.content) - if role == "user": - return HumanMessage(content=content) - elif role == "assistant": - additional_kwargs: Dict = {} - if hasattr(_message, "tool_calls") and getattr(_message, "tool_calls"): - additional_kwargs["tool_calls"] = [ - tc.model_dump() for tc in getattr(_message, "tool_calls") - ] - return AIMessage(content=content, additional_kwargs=additional_kwargs) - elif role == "system": - return SystemMessage(content=content) - elif role == "tool": - return ToolMessage(content=content, name=_message.name) # type: 
ignore[attr-defined] - else: - return ChatMessage(content=content, role=role) + role = _message["role"] + assert role == "assistant", f"Expected role to be 'assistant', got {role}" + content = cast(str, _message["content"]) + + additional_kwargs: Dict = {} + if tool_calls := _message.get("tool_calls"): + additional_kwargs["tool_calls"] = tool_calls + return AIMessage(content=content, additional_kwargs=additional_kwargs) + + +async def _aiter_sse( + event_source_mgr: AsyncContextManager[EventSource], +) -> AsyncIterator[Dict]: + """Iterate over the server-sent events.""" + async with event_source_mgr as event_source: + async for event in event_source.aiter_sse(): + if event.data == "[DONE]": + return + yield event.json() async def acompletion_with_retry( @@ -121,28 +108,33 @@ async def acompletion_with_retry( @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: - stream = kwargs.pop("stream", False) + if "stream" not in kwargs: + kwargs["stream"] = False + stream = kwargs["stream"] if stream: - return llm.async_client.chat_stream(**kwargs) + event_source = aconnect_sse( + llm.async_client, "POST", "/chat/completions", json=kwargs + ) + + return _aiter_sse(event_source) else: - return await llm.async_client.chat(**kwargs) + response = await llm.async_client.post(url="/chat/completions", json=kwargs) + return response.json() return await _completion_with_retry(**kwargs) def _convert_delta_to_message_chunk( - _delta: MistralDeltaMessage, default_class: Type[BaseMessageChunk] + _delta: Dict, default_class: Type[BaseMessageChunk] ) -> BaseMessageChunk: - role = getattr(_delta, "role") - content = getattr(_delta, "content", "") + role = _delta.get("role") + content = _delta.get("content", "") if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) elif role == "assistant" or default_class == AIMessageChunk: additional_kwargs: Dict = {} - if hasattr(_delta, "tool_calls") and getattr(_delta, "tool_calls"): - additional_kwargs["tool_calls"] = [ - tc.model_dump() for tc in getattr(_delta, "tool_calls") - ] + if tool_calls := _delta.get("tool_calls"): + additional_kwargs["tool_calls"] = [tc.model_dump() for tc in tool_calls] return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) elif role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) @@ -154,44 +146,48 @@ def _convert_delta_to_message_chunk( def _convert_message_to_mistral_chat_message( message: BaseMessage, -) -> MistralChatMessage: +) -> Dict: if isinstance(message, ChatMessage): - mistral_message = MistralChatMessage(role=message.role, content=message.content) + return dict(role=message.role, content=message.content) elif isinstance(message, HumanMessage): - mistral_message = MistralChatMessage(role="user", content=message.content) + return dict(role="user", content=message.content) elif isinstance(message, AIMessage): if "tool_calls" in message.additional_kwargs: - from mistralai.models.chat_completion import ( # type: ignore[attr-defined] - ToolCall as MistralToolCall, - ) - tool_calls = [ - MistralToolCall.model_validate(tc) + { + "function": { + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], + } + } for tc in message.additional_kwargs["tool_calls"] ] else: tool_calls = None - mistral_message = MistralChatMessage( - role="assistant", content=message.content, tool_calls=tool_calls - ) + return { + "role": "assistant", + "content": message.content, + "tool_calls": tool_calls, + } elif 
isinstance(message, SystemMessage): - mistral_message = MistralChatMessage(role="system", content=message.content) + return dict(role="system", content=message.content) elif isinstance(message, ToolMessage): - mistral_message = MistralChatMessage( - role="tool", content=message.content, name=message.name - ) + return { + "role": "tool", + "content": message.content, + "name": message.name, + } else: raise ValueError(f"Got unknown type {message}") - return mistral_message class ChatMistralAI(BaseChatModel): """A chat model that uses the MistralAI API.""" - client: MistralClient = Field(default=None) #: :meta private: - async_client: MistralAsyncClient = Field(default=None) #: :meta private: + client: httpx.Client = Field(default=None) #: :meta private: + async_client: httpx.AsyncClient = Field(default=None) #: :meta private: mistral_api_key: Optional[SecretStr] = None - endpoint: str = DEFAULT_MISTRAL_ENDPOINT + endpoint: str = "https://api.mistral.ai/v1" max_retries: int = 5 timeout: int = 120 max_concurrent_requests: int = 64 @@ -204,6 +200,7 @@ class ChatMistralAI(BaseChatModel): probability sum is at least top_p. Must be in the closed interval [0.0, 1.0].""" random_seed: Optional[int] = None safe_mode: bool = False + streaming: bool = False @property def _default_params(self) -> Dict[str, Any]: @@ -214,7 +211,7 @@ def _default_params(self) -> Dict[str, Any]: "max_tokens": self.max_tokens, "top_p": self.top_p, "random_seed": self.random_seed, - "safe_mode": self.safe_mode, + "safe_prompt": self.safe_mode, } filtered = {k: v for k, v in defaults.items() if v is not None} return filtered @@ -228,45 +225,60 @@ def completion_with_retry( self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any ) -> Any: """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + # retry_decorator = _create_retry_decorator(self, run_manager=run_manager) - @retry_decorator + # @retry_decorator def _completion_with_retry(**kwargs: Any) -> Any: - stream = kwargs.pop("stream", False) + if "stream" not in kwargs: + kwargs["stream"] = False + stream = kwargs["stream"] if stream: - return self.client.chat_stream(**kwargs) + + def iter_sse() -> Iterator[Dict]: + with connect_sse( + self.client, "POST", "/chat/completions", json=kwargs + ) as event_source: + for event in event_source.iter_sse(): + if event.data == "[DONE]": + return + yield event.json() + + return iter_sse() else: - return self.client.chat(**kwargs) + return self.client.post(url="/chat/completions", json=kwargs).json() - return _completion_with_retry(**kwargs) + rtn = _completion_with_retry(**kwargs) + return rtn @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists, temperature, and top_p.""" - mistralai_spec = importlib.util.find_spec("mistralai") - if mistralai_spec is None: - raise MistralException( - "Could not find mistralai python package. 
" - "Please install it with `pip install mistralai`" - ) values["mistral_api_key"] = convert_to_secret_str( get_from_dict_or_env( values, "mistral_api_key", "MISTRAL_API_KEY", default="" ) ) - values["client"] = MistralClient( - api_key=values["mistral_api_key"].get_secret_value(), - endpoint=values["endpoint"], - max_retries=values["max_retries"], + api_key_str = values["mistral_api_key"].get_secret_value() + # todo: handle retries + values["client"] = httpx.Client( + base_url=values["endpoint"], + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {api_key_str}", + }, timeout=values["timeout"], ) - values["async_client"] = MistralAsyncClient( - api_key=values["mistral_api_key"].get_secret_value(), - endpoint=values["endpoint"], - max_retries=values["max_retries"], + # todo: handle retries and max_concurrency + values["async_client"] = httpx.AsyncClient( + base_url=values["endpoint"], + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {api_key_str}", + }, timeout=values["timeout"], - max_concurrent_requests=values["max_concurrent_requests"], ) if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: @@ -285,7 +297,7 @@ def _generate( stream: Optional[bool] = None, **kwargs: Any, ) -> ChatResult: - should_stream = stream if stream is not None else False + should_stream = stream if stream is not None else self.streaming if should_stream: stream_iter = self._stream( messages, stop=stop, run_manager=run_manager, **kwargs @@ -299,27 +311,23 @@ def _generate( ) return self._create_chat_result(response) - def _create_chat_result( - self, response: MistralChatCompletionResponse - ) -> ChatResult: + def _create_chat_result(self, response: Dict) -> ChatResult: generations = [] - for res in response.choices: - finish_reason = getattr(res, "finish_reason") - if finish_reason: - finish_reason = finish_reason.value + for res in response["choices"]: + finish_reason = res.get("finish_reason") gen = ChatGeneration( - message=_convert_mistral_chat_message_to_message(res.message), + message=_convert_mistral_chat_message_to_message(res["message"]), generation_info={"finish_reason": finish_reason}, ) generations.append(gen) - token_usage = getattr(response, "usage") - token_usage = vars(token_usage) if token_usage else {} + token_usage = response.get("usage", {}) + llm_output = {"token_usage": token_usage, "model": self.model} return ChatResult(generations=generations, llm_output=llm_output) def _create_message_dicts( self, messages: List[BaseMessage], stop: Optional[List[str]] - ) -> Tuple[List[MistralChatMessage], Dict[str, Any]]: + ) -> Tuple[List[Dict], Dict[str, Any]]: params = self._client_params if stop is not None or "stop" in params: if "stop" in params: @@ -340,20 +348,24 @@ def _stream( message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} - default_chunk_class = AIMessageChunk + default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk for chunk in self.completion_with_retry( messages=message_dicts, run_manager=run_manager, **params ): - if len(chunk.choices) == 0: + if len(chunk["choices"]) == 0: continue - delta = chunk.choices[0].delta - if not delta.content: + delta = chunk["choices"][0]["delta"] + if not delta["content"]: continue - chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) - default_chunk_class = chunk.__class__ + new_chunk = _convert_delta_to_message_chunk(delta, 
default_chunk_class) + # make future chunks same type as first chunk + default_chunk_class = new_chunk.__class__ + gen_chunk = ChatGenerationChunk(message=new_chunk) if run_manager: - run_manager.on_llm_new_token(token=chunk.content, chunk=chunk) - yield ChatGenerationChunk(message=chunk) + run_manager.on_llm_new_token( + token=cast(str, new_chunk.content), chunk=gen_chunk + ) + yield gen_chunk async def _astream( self, @@ -365,20 +377,24 @@ async def _astream( message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} - default_chunk_class = AIMessageChunk + default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk async for chunk in await acompletion_with_retry( self, messages=message_dicts, run_manager=run_manager, **params ): - if len(chunk.choices) == 0: + if len(chunk["choices"]) == 0: continue - delta = chunk.choices[0].delta - if not delta.content: + delta = chunk["choices"][0]["delta"] + if not delta["content"]: continue - chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) - default_chunk_class = chunk.__class__ + new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + # make future chunks same type as first chunk + default_chunk_class = new_chunk.__class__ + gen_chunk = ChatGenerationChunk(message=new_chunk) if run_manager: - await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk) - yield ChatGenerationChunk(message=chunk) + await run_manager.on_llm_new_token( + token=cast(str, new_chunk.content), chunk=gen_chunk + ) + yield gen_chunk async def _agenerate( self, diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index 977dbdf95908d..e58f7d3692ea8 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -2,6 +2,7 @@ import logging from typing import Dict, Iterable, List, Optional +import httpx from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import ( BaseModel, @@ -11,12 +12,6 @@ root_validator, ) from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env -from mistralai.async_client import MistralAsyncClient -from mistralai.client import MistralClient -from mistralai.constants import ( - ENDPOINT as DEFAULT_MISTRAL_ENDPOINT, -) -from mistralai.exceptions import MistralException from tokenizers import Tokenizer # type: ignore logger = logging.getLogger(__name__) @@ -40,10 +35,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings): ) """ - client: MistralClient = Field(default=None) #: :meta private: - async_client: MistralAsyncClient = Field(default=None) #: :meta private: + client: httpx.Client = Field(default=None) #: :meta private: + async_client: httpx.AsyncClient = Field(default=None) #: :meta private: mistral_api_key: Optional[SecretStr] = None - endpoint: str = DEFAULT_MISTRAL_ENDPOINT + endpoint: str = "https://api.mistral.ai/v1/" max_retries: int = 5 timeout: int = 120 max_concurrent_requests: int = 64 @@ -64,18 +59,26 @@ def validate_environment(cls, values: Dict) -> Dict: values, "mistral_api_key", "MISTRAL_API_KEY", default="" ) ) - values["client"] = MistralClient( - api_key=values["mistral_api_key"].get_secret_value(), - endpoint=values["endpoint"], - max_retries=values["max_retries"], + api_key_str = values["mistral_api_key"].get_secret_value() + # todo: handle retries + values["client"] = httpx.Client( + base_url=values["endpoint"], + headers={ + 
"Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {api_key_str}", + }, timeout=values["timeout"], ) - values["async_client"] = MistralAsyncClient( - api_key=values["mistral_api_key"].get_secret_value(), - endpoint=values["endpoint"], - max_retries=values["max_retries"], + # todo: handle retries and max_concurrency + values["async_client"] = httpx.AsyncClient( + base_url=values["endpoint"], + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {api_key_str}", + }, timeout=values["timeout"], - max_concurrent_requests=values["max_concurrent_requests"], ) if values["tokenizer"] is None: values["tokenizer"] = Tokenizer.from_pretrained( @@ -115,18 +118,21 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: """ try: batch_responses = ( - self.client.embeddings( - model=self.model, - input=batch, + self.client.post( + url="/embeddings", + json=dict( + model=self.model, + input=batch, + ), ) for batch in self._get_batches(texts) ) return [ - list(map(float, embedding_obj.embedding)) + list(map(float, embedding_obj["embedding"])) for response in batch_responses - for embedding_obj in response.data + for embedding_obj in response.json()["data"] ] - except MistralException as e: + except Exception as e: logger.error(f"An error occurred with MistralAI: {e}") raise @@ -142,19 +148,22 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]: try: batch_responses = await asyncio.gather( *[ - self.async_client.embeddings( - model=self.model, - input=batch, + self.async_client.post( + url="/embeddings", + json=dict( + model=self.model, + input=batch, + ), ) for batch in self._get_batches(texts) ] ) return [ - list(map(float, embedding_obj.embedding)) + list(map(float, embedding_obj["embedding"])) for response in batch_responses - for embedding_obj in response.data + for embedding_obj in response.json()["data"] ] - except MistralException as e: + except Exception as e: logger.error(f"An error occurred with MistralAI: {e}") raise diff --git a/libs/partners/mistralai/poetry.lock b/libs/partners/mistralai/poetry.lock index cb6b81615a51c..0fa4897c9419b 100644 --- a/libs/partners/mistralai/poetry.lock +++ b/libs/partners/mistralai/poetry.lock @@ -206,13 +206,13 @@ typing = ["typing-extensions (>=4.8)"] [[package]] name = "fsspec" -version = "2024.2.0" +version = "2024.3.1" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, - {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, + {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, + {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, ] [package.extras] @@ -273,13 +273,13 @@ trio = ["trio (>=0.22.0,<0.25.0)"] [[package]] name = "httpx" -version = "0.25.2" +version = "0.27.0" description = "The next generation HTTP client." 
optional = false python-versions = ">=3.8" files = [ - {file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"}, - {file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"}, + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, ] [package.dependencies] @@ -295,15 +295,26 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "httpx-sse" +version = "0.4.0" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, + {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"}, +] + [[package]] name = "huggingface-hub" -version = "0.20.3" +version = "0.21.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.20.3-py3-none-any.whl", hash = "sha256:d988ae4f00d3e307b0c80c6a05ca6dbb7edba8bba3079f74cda7d9c2e562a7b6"}, - {file = "huggingface_hub-0.20.3.tar.gz", hash = "sha256:94e7f8e074475fbc67d6a71957b678e1b4a74ff1b64a644fd6cbb83da962d05d"}, + {file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"}, + {file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"}, ] [package.dependencies] @@ -320,11 +331,12 @@ all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", cli = ["InquirerPy (==0.3.4)"] dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +hf-transfer = ["hf-transfer (>=0.1.4)"] inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] tensorflow = ["graphviz", "pydot", "tensorflow"] testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] -torch = ["torch"] +torch = ["safetensors", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] [[package]] @@ -376,7 +388,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.1.27" +version = "0.1.33" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -402,13 +414,13 @@ url = "../../core" [[package]] name = "langsmith" 
-version = "0.1.8" +version = "0.1.31" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.8-py3-none-any.whl", hash = "sha256:f4320fd80ec9d311a648e7d4c44e0814e6e5454772c5026f40db0307bc07e287"}, - {file = "langsmith-0.1.8.tar.gz", hash = "sha256:ab5f1cdfb7d418109ea506d41928fb8708547db2f6c7f7da7cfe997f3c55767b"}, + {file = "langsmith-0.1.31-py3-none-any.whl", hash = "sha256:5211a9dc00831db307eb843485a97096484b697b5d2cd1efaac34228e97ca087"}, + {file = "langsmith-0.1.31.tar.gz", hash = "sha256:efd54ccd44be7fda911bfdc0ead340473df2fdd07345c7252901834d0c4aa37e"}, ] [package.dependencies] @@ -416,40 +428,6 @@ orjson = ">=3.9.14,<4.0.0" pydantic = ">=1,<3" requests = ">=2,<3" -[[package]] -name = "mistralai" -version = "0.0.12" -description = "" -optional = false -python-versions = ">=3.8,<4.0" -files = [ - {file = "mistralai-0.0.12-py3-none-any.whl", hash = "sha256:d489d1f0a31bf0edbe15c6d12f68b943148d2a725a088be0d8a5d4c888f8436c"}, - {file = "mistralai-0.0.12.tar.gz", hash = "sha256:fe652836146a15bdce7691a95803a32c53c641c5400093447ffa93bf2ed296b2"}, -] - -[package.dependencies] -httpx = ">=0.25.2,<0.26.0" -orjson = ">=3.9.10,<4.0.0" -pydantic = ">=2.5.2,<3.0.0" - -[[package]] -name = "mistralai" -version = "0.1.2" -description = "" -optional = false -python-versions = ">=3.9,<4.0" -files = [ - {file = "mistralai-0.1.2-py3-none-any.whl", hash = "sha256:5e74e5ef0c0f15058892d73b00c659e06e9882c00838a1ad9862d93c77336847"}, - {file = "mistralai-0.1.2.tar.gz", hash = "sha256:eb915fd15075f71bdbfce9cb476bb647322b1ce1e93b19ab0047728067466397"}, -] - -[package.dependencies] -httpx = ">=0.25.2,<0.26.0" -orjson = ">=3.9.10,<4.0.0" -pandas = ">=2.2.0,<3.0.0" -pyarrow = ">=15.0.0,<16.0.0" -pydantic = ">=2.5.2,<3.0.0" - [[package]] name = "mypy" version = "0.991" @@ -511,51 +489,6 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] -[[package]] -name = "numpy" -version = "1.26.4" -description = "Fundamental package for array computing in Python" -optional = false -python-versions = ">=3.9" -files = [ - {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, - {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, - {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, - {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, - {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, - {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, - {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, - {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, - {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, - {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, - {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, - {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, - {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, - {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, - {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, - {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, - {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, - {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, - {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, - {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, - {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, - {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, - {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, - {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, - {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, - {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, - {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, - {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, - {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, - {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, - {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, - {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, - {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, - {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, -] - [[package]] name = "orjson" version = "3.9.15" @@ -626,79 +559,6 @@ files = [ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] -[[package]] -name = "pandas" -version = "2.2.1" -description = "Powerful data structures for data analysis, time series, and statistics" -optional = false -python-versions = ">=3.9" -files = [ - {file = "pandas-2.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8df8612be9cd1c7797c93e1c5df861b2ddda0b48b08f2c3eaa0702cf88fb5f88"}, - {file = "pandas-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0f573ab277252ed9aaf38240f3b54cfc90fff8e5cab70411ee1d03f5d51f3944"}, - {file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f02a3a6c83df4026e55b63c1f06476c9aa3ed6af3d89b4f04ea656ccdaaaa359"}, - {file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c38ce92cb22a4bea4e3929429aa1067a454dcc9c335799af93ba9be21b6beb51"}, - {file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c2ce852e1cf2509a69e98358e8458775f89599566ac3775e70419b98615f4b06"}, - {file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53680dc9b2519cbf609c62db3ed7c0b499077c7fefda564e330286e619ff0dd9"}, - {file = "pandas-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:94e714a1cca63e4f5939cdce5f29ba8d415d85166be3441165edd427dc9f6bc0"}, - {file = "pandas-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f821213d48f4ab353d20ebc24e4faf94ba40d76680642fb7ce2ea31a3ad94f9b"}, - {file = "pandas-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c70e00c2d894cb230e5c15e4b1e1e6b2b478e09cf27cc593a11ef955b9ecc81a"}, - {file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e97fbb5387c69209f134893abc788a6486dbf2f9e511070ca05eed4b930b1b02"}, - {file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101d0eb9c5361aa0146f500773395a03839a5e6ecde4d4b6ced88b7e5a1a6403"}, - {file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:7d2ed41c319c9fb4fd454fe25372028dfa417aacb9790f68171b2e3f06eae8cd"}, - {file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:af5d3c00557d657c8773ef9ee702c61dd13b9d7426794c9dfeb1dc4a0bf0ebc7"}, - {file = "pandas-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:06cf591dbaefb6da9de8472535b185cba556d0ce2e6ed28e21d919704fef1a9e"}, - {file = "pandas-2.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:88ecb5c01bb9ca927ebc4098136038519aa5d66b44671861ffab754cae75102c"}, - {file = "pandas-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:04f6ec3baec203c13e3f8b139fb0f9f86cd8c0b94603ae3ae8ce9a422e9f5bee"}, - {file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a935a90a76c44fe170d01e90a3594beef9e9a6220021acfb26053d01426f7dc2"}, - {file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:c391f594aae2fd9f679d419e9a4d5ba4bce5bb13f6a989195656e7dc4b95c8f0"}, - {file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9d1265545f579edf3f8f0cb6f89f234f5e44ba725a34d86535b1a1d38decbccc"}, - {file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:11940e9e3056576ac3244baef2fedade891977bcc1cb7e5cc8f8cc7d603edc89"}, - {file = "pandas-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4acf681325ee1c7f950d058b05a820441075b0dd9a2adf5c4835b9bc056bf4fb"}, - {file = "pandas-2.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9bd8a40f47080825af4317d0340c656744f2bfdb6819f818e6ba3cd24c0e1397"}, - {file = "pandas-2.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:df0c37ebd19e11d089ceba66eba59a168242fc6b7155cba4ffffa6eccdfb8f16"}, - {file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:739cc70eaf17d57608639e74d63387b0d8594ce02f69e7a0b046f117974b3019"}, - {file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9d3558d263073ed95e46f4650becff0c5e1ffe0fc3a015de3c79283dfbdb3df"}, - {file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4aa1d8707812a658debf03824016bf5ea0d516afdea29b7dc14cf687bc4d4ec6"}, - {file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:76f27a809cda87e07f192f001d11adc2b930e93a2b0c4a236fde5429527423be"}, - {file = "pandas-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:1ba21b1d5c0e43416218db63037dbe1a01fc101dc6e6024bcad08123e48004ab"}, - {file = "pandas-2.2.1.tar.gz", hash = "sha256:0ab90f87093c13f3e8fa45b48ba9f39181046e8f3317d3aadb2fffbb1b978572"}, -] - -[package.dependencies] -numpy = [ - {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, -] -python-dateutil = ">=2.8.2" -pytz = ">=2020.1" -tzdata = ">=2022.7" - -[package.extras] -all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] -aws = ["s3fs (>=2022.11.0)"] -clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] -compression = ["zstandard (>=0.19.0)"] -computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] -consortium-standard = ["dataframe-api-compat (>=0.1.7)"] -excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] -feather = ["pyarrow (>=10.0.1)"] -fss = ["fsspec (>=2022.11.0)"] -gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] -hdf5 = ["tables (>=3.8.0)"] -html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] -mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] -output-formatting 
= ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] -parquet = ["pyarrow (>=10.0.1)"] -performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] -plot = ["matplotlib (>=3.6.3)"] -postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] -pyarrow = ["pyarrow (>=10.0.1)"] -spss = ["pyreadstat (>=1.2.0)"] -sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] -test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] -xml = ["lxml (>=4.9.2)"] - [[package]] name = "pluggy" version = "1.4.0" @@ -714,63 +574,15 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] -[[package]] -name = "pyarrow" -version = "15.0.0" -description = "Python library for Apache Arrow" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pyarrow-15.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:0a524532fd6dd482edaa563b686d754c70417c2f72742a8c990b322d4c03a15d"}, - {file = "pyarrow-15.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a6bdb314affa9c2e0d5dddf3d9cbb9ef4a8dddaa68669975287d47ece67642"}, - {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66958fd1771a4d4b754cd385835e66a3ef6b12611e001d4e5edfcef5f30391e2"}, - {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f500956a49aadd907eaa21d4fff75f73954605eaa41f61cb94fb008cf2e00c6"}, - {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6f87d9c4f09e049c2cade559643424da84c43a35068f2a1c4653dc5b1408a929"}, - {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85239b9f93278e130d86c0e6bb455dcb66fc3fd891398b9d45ace8799a871a1e"}, - {file = "pyarrow-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b8d43e31ca16aa6e12402fcb1e14352d0d809de70edd185c7650fe80e0769e3"}, - {file = "pyarrow-15.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa7cd198280dbd0c988df525e50e35b5d16873e2cdae2aaaa6363cdb64e3eec5"}, - {file = "pyarrow-15.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8780b1a29d3c8b21ba6b191305a2a607de2e30dab399776ff0aa09131e266340"}, - {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe0ec198ccc680f6c92723fadcb97b74f07c45ff3fdec9dd765deb04955ccf19"}, - {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036a7209c235588c2f07477fe75c07e6caced9b7b61bb897c8d4e52c4b5f9555"}, - {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2bd8a0e5296797faf9a3294e9fa2dc67aa7f10ae2207920dbebb785c77e9dbe5"}, - {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e8ebed6053dbe76883a822d4e8da36860f479d55a762bd9e70d8494aed87113e"}, - {file = "pyarrow-15.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:17d53a9d1b2b5bd7d5e4cd84d018e2a45bc9baaa68f7e6e3ebed45649900ba99"}, - {file = "pyarrow-15.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9950a9c9df24090d3d558b43b97753b8f5867fb8e521f29876aa021c52fda351"}, - {file = "pyarrow-15.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:003d680b5e422d0204e7287bb3fa775b332b3fce2996aa69e9adea23f5c8f970"}, - {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f75fce89dad10c95f4bf590b765e3ae98bcc5ba9f6ce75adb828a334e26a3d40"}, - {file = 
"pyarrow-15.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca9cb0039923bec49b4fe23803807e4ef39576a2bec59c32b11296464623dc2"}, - {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ed5a78ed29d171d0acc26a305a4b7f83c122d54ff5270810ac23c75813585e4"}, - {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6eda9e117f0402dfcd3cd6ec9bfee89ac5071c48fc83a84f3075b60efa96747f"}, - {file = "pyarrow-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a3a6180c0e8f2727e6f1b1c87c72d3254cac909e609f35f22532e4115461177"}, - {file = "pyarrow-15.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19a8918045993349b207de72d4576af0191beef03ea655d8bdb13762f0cd6eac"}, - {file = "pyarrow-15.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0ec076b32bacb6666e8813a22e6e5a7ef1314c8069d4ff345efa6246bc38593"}, - {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5db1769e5d0a77eb92344c7382d6543bea1164cca3704f84aa44e26c67e320fb"}, - {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2617e3bf9df2a00020dd1c1c6dce5cc343d979efe10bc401c0632b0eef6ef5b"}, - {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:d31c1d45060180131caf10f0f698e3a782db333a422038bf7fe01dace18b3a31"}, - {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c8c287d1d479de8269398b34282e206844abb3208224dbdd7166d580804674b7"}, - {file = "pyarrow-15.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:07eb7f07dc9ecbb8dace0f58f009d3a29ee58682fcdc91337dfeb51ea618a75b"}, - {file = "pyarrow-15.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:47af7036f64fce990bb8a5948c04722e4e3ea3e13b1007ef52dfe0aa8f23cf7f"}, - {file = "pyarrow-15.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93768ccfff85cf044c418bfeeafce9a8bb0cee091bd8fd19011aff91e58de540"}, - {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6ee87fd6892700960d90abb7b17a72a5abb3b64ee0fe8db6c782bcc2d0dc0b4"}, - {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001fca027738c5f6be0b7a3159cc7ba16a5c52486db18160909a0831b063c4e4"}, - {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:d1c48648f64aec09accf44140dccb92f4f94394b8d79976c426a5b79b11d4fa7"}, - {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:972a0141be402bb18e3201448c8ae62958c9c7923dfaa3b3d4530c835ac81aed"}, - {file = "pyarrow-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:f01fc5cf49081426429127aa2d427d9d98e1cb94a32cb961d583a70b7c4504e6"}, - {file = "pyarrow-15.0.0.tar.gz", hash = "sha256:876858f549d540898f927eba4ef77cd549ad8d24baa3207cf1b72e5788b50e83"}, -] - -[package.dependencies] -numpy = ">=1.16.6,<2" - [[package]] name = "pydantic" -version = "2.6.2" +version = "2.6.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"}, - {file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"}, + {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"}, + {file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"}, ] 
[package.dependencies] @@ -912,31 +724,6 @@ pytest = ">=7.0.0" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] -[[package]] -name = "python-dateutil" -version = "2.8.2" -description = "Extensions to the standard Python datetime module" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" -files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, -] - -[package.dependencies] -six = ">=1.5" - -[[package]] -name = "pytz" -version = "2024.1" -description = "World timezone definitions, modern and historical" -optional = false -python-versions = "*" -files = [ - {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, - {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, -] - [[package]] name = "pyyaml" version = "6.0.1" @@ -1044,17 +831,6 @@ files = [ {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, ] -[[package]] -name = "six" -version = "1.16.0" -description = "Python 2 and 3 compatibility utilities" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, - {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, -] - [[package]] name = "sniffio" version = "1.3.1" @@ -1249,17 +1025,6 @@ files = [ {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] -[[package]] -name = "tzdata" -version = "2024.1" -description = "Provider of IANA time zone data" -optional = false -python-versions = ">=2" -files = [ - {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, - {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, -] - [[package]] name = "urllib3" version = "2.2.1" @@ -1280,4 +1045,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "ccb95664a734631dde949975506ab160f65cdd222b28bf4f702fb4b11644f418" +content-hash = "706b13139d3f36b3fffb311155ec5bba970f24a692146f7deed08cb8cfe5c962" diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml index f5347a67759e3..33fe734b3dd41 100644 --- a/libs/partners/mistralai/pyproject.toml +++ b/libs/partners/mistralai/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-mistralai" -version = "0.0.5" +version = "0.1.0rc1" description = "An integration package connecting Mistral and LangChain" authors = [] readme = "README.md" @@ -12,9 +12,10 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -langchain-core = "^0.1.27" -mistralai = [{version = "^0.1", python = "^3.9"}, {version = ">=0.0.11,<0.2", python="3.8"}] +langchain-core = "^0.1.31" tokenizers = "^0.15.1" +httpx = ">=0.25.2,<1" +httpx-sse = ">=0.3.1,<1" [tool.poetry.group.test] optional = true @@ -24,16 +25,16 @@ pytest = "^7.3.0" 
 pytest-asyncio = "^0.21.1"
 langchain-core = { path = "../../core", develop = true }
-[tool.poetry.group.codespell]
+[tool.poetry.group.test_integration]
 optional = true
-[tool.poetry.group.codespell.dependencies]
-codespell = "^2.2.0"
+[tool.poetry.group.test_integration.dependencies]
-[tool.poetry.group.test_integration]
+[tool.poetry.group.codespell]
 optional = true
-[tool.poetry.group.test_integration.dependencies]
+[tool.poetry.group.codespell.dependencies]
+codespell = "^2.2.0"
 [tool.poetry.group.lint]
 optional = true
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index 1bdd99305ba7f..d4086643ebcb0 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -1,4 +1,5 @@
 """Test ChatMistral chat model."""
+
 from langchain_mistralai.chat_models import ChatMistralAI
@@ -61,3 +62,24 @@ def test_invoke() -> None:
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+
+
+def test_structured_output() -> None:
+    llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
+    schema = {
+        "title": "AnswerWithJustification",
+        "description": (
+            "An answer to the user question along with justification for the answer."
+        ),
+        "type": "object",
+        "properties": {
+            "answer": {"title": "Answer", "type": "string"},
+            "justification": {"title": "Justification", "type": "string"},
+        },
+        "required": ["answer", "justification"],
+    }
+    structured_llm = llm.with_structured_output(schema)
+    result = structured_llm.invoke(
+        "What weighs more a pound of bricks or a pound of feathers"
+    )
+    assert isinstance(result, dict)
diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
index 8a28a916f591e..f7aa3a749ab30 100644
--- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
@@ -1,6 +1,7 @@
 """Test MistralAI Chat API wrapper."""
+
 import os
-from typing import Any, AsyncGenerator, Generator
+from typing import Any, AsyncGenerator, Dict, Generator
 from unittest.mock import patch
 import pytest
@@ -13,16 +14,6 @@
     SystemMessage,
 )
-# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
-from mistralai.models.chat_completion import ( # type: ignore[import]
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    DeltaMessage,
-)
-from mistralai.models.chat_completion import (
-    ChatMessage as MistralChatMessage,
-)
-
 from langchain_mistralai.chat_models import ( # type: ignore[import]
     ChatMistralAI,
     _convert_message_to_mistral_chat_message,
@@ -31,13 +22,11 @@
 os.environ["MISTRAL_API_KEY"] = "foo"
-@pytest.mark.requires("mistralai")
 def test_mistralai_model_param() -> None:
     llm = ChatMistralAI(model="foo")
     assert llm.model == "foo"
-@pytest.mark.requires("mistralai")
 def test_mistralai_initialization() -> None:
     """Test ChatMistralAI initialization."""
     # Verify that ChatMistralAI can be initialized using a secret key provided
@@ -50,37 +39,37 @@ def test_mistralai_initialization() -> None:
     [
         (
             SystemMessage(content="Hello"),
-            MistralChatMessage(role="system", content="Hello"),
+            dict(role="system", content="Hello"),
         ),
         (
             HumanMessage(content="Hello"),
-            MistralChatMessage(role="user", content="Hello"),
+            dict(role="user", content="Hello"),
         ),
         (
             AIMessage(content="Hello"),
-            MistralChatMessage(role="assistant", content="Hello"),
+            dict(role="assistant", content="Hello", tool_calls=None),
        ),
         (
             ChatMessage(role="assistant", content="Hello"),
-            MistralChatMessage(role="assistant", content="Hello"),
+            dict(role="assistant", content="Hello"),
         ),
     ],
 )
 def test_convert_message_to_mistral_chat_message(
-    message: BaseMessage, expected: MistralChatMessage
+    message: BaseMessage, expected: Dict
 ) -> None:
     result = _convert_message_to_mistral_chat_message(message)
     assert result == expected
-def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResponse:
-    return ChatCompletionStreamResponse(
+def _make_completion_response_from_token(token: str) -> Dict:
+    return dict(
         id="abc123",
         model="fake_model",
         choices=[
-            ChatCompletionResponseStreamChoice(
+            dict(
                 index=0,
-                delta=DeltaMessage(content=token),
+                delta=dict(content=token),
                 finish_reason=None,
             )
         ],
@@ -88,13 +77,19 @@ def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResp
 def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
-    for token in ["Hello", " how", " can", " I", " help", "?"]:
-        yield _make_completion_response_from_token(token)
+    def it() -> Generator:
+        for token in ["Hello", " how", " can", " I", " help", "?"]:
+            yield _make_completion_response_from_token(token)
+
+    return it()
 async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
-    for token in ["Hello", " how", " can", " I", " help", "?"]:
-        yield _make_completion_response_from_token(token)
+    async def it() -> AsyncGenerator:
+        for token in ["Hello", " how", " can", " I", " help", "?"]:
+            yield _make_completion_response_from_token(token)
+
+    return it()
 class MyCustomHandler(BaseCallbackHandler):
@@ -104,7 +99,10 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         self.last_token = token
-@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream)
+@patch(
+    "langchain_mistralai.chat_models.ChatMistralAI.completion_with_retry",
+    new=mock_chat_stream,
+)
 def test_stream_with_callback() -> None:
     callback = MyCustomHandler()
     chat = ChatMistralAI(callbacks=[callback])
@@ -112,7 +110,7 @@ def test_stream_with_callback() -> None:
     assert callback.last_token == token.content
-@patch("mistralai.async_client.MistralAsyncClient.chat_stream", new=mock_chat_astream)
+@patch("langchain_mistralai.chat_models.acompletion_with_retry", new=mock_chat_astream)
 async def test_astream_with_callback() -> None:
     callback = MyCustomHandler()
     chat = ChatMistralAI(callbacks=[callback])
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index ed97a6c73794c..4b8ec3e016ea6 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -141,14 +141,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     message_dict: Dict[str, Any] = {
         "content": message.content,
     }
-    if message.name is not None:
-        message_dict["name"] = message.name
-    elif (
-        "name" in message.additional_kwargs
-        and message.additional_kwargs["name"] is not None
-    ):
-        # fall back on additional kwargs for backwards compatibility
-        message_dict["name"] = message.additional_kwargs["name"]
+    if (name := message.name or message.additional_kwargs.get("name")) is not None:
+        message_dict["name"] = name
     # populate role and additional message data
     if isinstance(message, ChatMessage):
@@ -175,9 +169,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict["role"] = "tool"
         message_dict["tool_call_id"] = message.tool_call_id
-        # tool message doesn't have name: https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
-        if message_dict["name"] is None:
-            del message_dict["name"]
+        supported_props = {"content", "role", "tool_call_id"}
+        message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
     else:
         raise TypeError(f"Got unknown type {message}")
     return message_dict
diff --git a/libs/partners/openai/poetry.lock b/libs/partners/openai/poetry.lock
index a5f802c763555..fb3b112cbf65b 100644
--- a/libs/partners/openai/poetry.lock
+++ b/libs/partners/openai/poetry.lock
@@ -174,6 +174,73 @@ files = [
     {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
 ]
+[[package]]
+name = "coverage"
+version = "7.4.4"
+description = "Code coverage measurement for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
+    {file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
+    {file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
+    {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
+    {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
+    {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
+    {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
+    {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
+    {file = "coverage-7.4.4-cp310-cp310-win32.whl", hash =
"sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"}, + {file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"}, + {file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"}, + {file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"}, + {file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"}, + {file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"}, + {file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"}, + {file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"}, + {file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"}, + {file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"}, + {file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"}, + {file = 
"coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"}, + {file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"}, + {file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"}, + {file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"}, + {file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"}, + {file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"}, + {file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"}, + {file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"}, + {file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + [[package]] name = "distro" version = "1.9.0" @@ -714,6 +781,24 @@ pytest = ">=7.0.0" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme 
(>=1.0)"] testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] +[[package]] +name = "pytest-cov" +version = "4.1.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, + {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + [[package]] name = "pytest-mock" version = "3.12.0" @@ -1185,4 +1270,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "3bac9595d36b9283144eda60bd3bcca227d573030f546fbc84fb99dd7419b603" +content-hash = "93b724f0c34c84f376c9607afc14059fc603f6c0c1b5fa4c153c5fce9cb10e63" diff --git a/libs/partners/openai/pyproject.toml b/libs/partners/openai/pyproject.toml index b89fea0232c70..8492954bbeeee 100644 --- a/libs/partners/openai/pyproject.toml +++ b/libs/partners/openai/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-openai" -version = "0.1.0" +version = "0.1.1" description = "An integration package connecting OpenAI and LangChain" authors = [] readme = "README.md" @@ -27,6 +27,7 @@ syrupy = "^4.0.2" pytest-watcher = "^0.3.4" pytest-asyncio = "^0.21.1" langchain-core = { path = "../../core", develop = true } +pytest-cov = "^4.1.0" [tool.poetry.group.codespell] optional = true @@ -89,7 +90,7 @@ build-backend = "poetry.core.masonry.api" # # https://github.com/tophat/syrupy # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. -addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" +addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5 --cov=langchain_openai" # Registering custom markers. 
 # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
 markers = [
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
index 24a5e26ee79f7..001b5296e3159 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
@@ -1,4 +1,5 @@
 """Test AzureChatOpenAI wrapper."""
+
 import os
 from typing import Any, Optional
@@ -6,6 +7,7 @@
 from langchain_core.callbacks import CallbackManager
 from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessage
 from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
+from langchain_core.pydantic_v1 import BaseModel
 from langchain_openai import AzureChatOpenAI
 from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -223,3 +225,18 @@ def test_openai_invoke(llm: AzureChatOpenAI) -> None:
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+
+
+@pytest.mark.skip(reason="Need tool calling model deployed on azure")
+def test_openai_structured_output(llm: AzureChatOpenAI) -> None:
+    class MyModel(BaseModel):
+        """A Person"""
+
+        name: str
+        age: int
+
+    llm_structure = llm.with_structured_output(MyModel)
+    result = llm_structure.invoke("I'm a 27 year old named Erick")
+    assert isinstance(result, MyModel)
+    assert result.name == "Erick"
+    assert result.age == 27
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index e40ab2e654ab7..50baef64d1465 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -1,5 +1,5 @@
 """Test ChatOpenAI chat model."""
-from typing import Any, Optional, cast
+from typing import Any, List, Optional, cast
 import pytest
 from langchain_core.callbacks import CallbackManager
@@ -9,6 +9,7 @@
     BaseMessageChunk,
     HumanMessage,
     SystemMessage,
+    ToolMessage,
 )
 from langchain_core.outputs import (
     ChatGeneration,
@@ -467,3 +468,36 @@ async def test_async_response_metadata_streaming() -> None:
         )
     )
     assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
+
+
+class GenerateUsername(BaseModel):
+    "Get a username based on someone's name and hair color."
+
+    name: str
+    hair_color: str
+
+
+def test_tool_use() -> None:
+    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
+    llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
+    msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
+    ai_msg = llm_with_tool.invoke(msgs)
+    tool_msg = ToolMessage(
+        "sally_green_hair", tool_call_id=ai_msg.additional_kwargs["tool_calls"][0]["id"]
+    )
+    msgs.extend([ai_msg, tool_msg])
+    llm_with_tool.invoke(msgs)
+
+
+def test_openai_structured_output() -> None:
+    class MyModel(BaseModel):
+        """A Person"""
+
+        name: str
+        age: int
+
+    llm = ChatOpenAI().with_structured_output(MyModel)
+    result = llm.invoke("I'm a 27 year old named Erick")
+    assert isinstance(result, MyModel)
+    assert result.name == "Erick"
+    assert result.age == 27
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 87e7111959ee8..4a9a64980571e 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -10,10 +10,14 @@
     FunctionMessage,
     HumanMessage,
     SystemMessage,
+    ToolMessage,
 )
 from langchain_openai import ChatOpenAI
-from langchain_openai.chat_models.base import _convert_dict_to_message
+from langchain_openai.chat_models.base import (
+    _convert_dict_to_message,
+    _convert_message_to_dict,
+)
 def test_openai_model_param() -> None:
@@ -43,6 +47,7 @@ def test__convert_dict_to_message_human() -> None:
     result = _convert_dict_to_message(message)
     expected_output = HumanMessage(content="foo")
     assert result == expected_output
+    assert _convert_message_to_dict(expected_output) == message
 def test__convert_dict_to_message_human_with_name() -> None:
@@ -50,6 +55,7 @@ def test__convert_dict_to_message_human_with_name() -> None:
     result = _convert_dict_to_message(message)
     expected_output = HumanMessage(content="foo", name="test")
     assert result == expected_output
+    assert _convert_message_to_dict(expected_output) == message
 def test__convert_dict_to_message_ai() -> None:
@@ -57,6 +63,7 @@ def test__convert_dict_to_message_ai() -> None:
     result = _convert_dict_to_message(message)
     expected_output = AIMessage(content="foo")
     assert result == expected_output
+    assert _convert_message_to_dict(expected_output) == message
 def test__convert_dict_to_message_ai_with_name() -> None:
@@ -64,6 +71,7 @@
     result = _convert_dict_to_message(message)
     expected_output = AIMessage(content="foo", name="test")
     assert result == expected_output
+    assert _convert_message_to_dict(expected_output) == message
 def test__convert_dict_to_message_system() -> None:
@@ -71,6 +79,7 @@ def test__convert_dict_to_message_system() -> None:
     result = _convert_dict_to_message(message)
     expected_output = SystemMessage(content="foo")
     assert result == expected_output
+    assert _convert_message_to_dict(expected_output) == message
 def test__convert_dict_to_message_system_with_name() -> None:
@@ -78,6 +87,15 @@ def test__convert_dict_to_message_system_with_name() -> None:
     result = _convert_dict_to_message(message)
     expected_output = SystemMessage(content="foo", name="test")
     assert result == expected_output
+    assert _convert_message_to_dict(expected_output) == message
+
+
+def test__convert_dict_to_message_tool() -> None:
+    message = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
+    result = _convert_dict_to_message(message)
+    expected_output = ToolMessage(content="foo", tool_call_id="bar")
+    assert result == expected_output
+    assert _convert_message_to_dict(expected_output) == message
 @pytest.fixture