From cbb00452b8b4dbc4cb52c0b720a73c799e2af163 Mon Sep 17 00:00:00 2001 From: Jerron Lim Date: Sat, 3 Aug 2024 19:18:48 +0800 Subject: [PATCH] Enable ollama models as agents (#98) --- README.md | 30 +++++ .../0a354b5c6f6c_create_writes_table.py | 7 +- ...bfce2_add_base_url_col_to_members_table.py | 29 +++++ .../versions/3a8a5f819c5f_add_thread_table.py | 4 +- .../6fa42be09dd2_add_checkpoints_table.py | 8 +- backend/app/core/graph/build.py | 13 ++- backend/app/core/graph/members.py | 34 +++++- backend/app/models.py | 1 + backend/poetry.lock | 103 ++++++++---------- backend/pyproject.toml | 10 +- frontend/src/client/models/MemberCreate.ts | 1 + frontend/src/client/models/MemberOut.ts | 1 + frontend/src/client/models/MemberUpdate.ts | 1 + frontend/src/client/schemas/$MemberCreate.ts | 8 ++ frontend/src/client/schemas/$MemberOut.ts | 8 ++ frontend/src/client/schemas/$MemberUpdate.ts | 8 ++ .../src/components/Members/EditMember.tsx | 92 +++++++++++++--- 17 files changed, 266 insertions(+), 92 deletions(-) create mode 100644 backend/app/alembic/versions/38a9c73bfce2_add_base_url_col_to_members_table.py diff --git a/README.md b/README.md index 35d1f68d..62d773f5 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,10 @@ - [Writing a Custom Skill using LangChain](#writing-a-custom-skill-using-langchain) - [Retrieval Augmented Generation (RAG)](#retrieval-augmented-generation-rag) - [Customising embedding models](#customising-embedding-models) + - [Using Open Source Models](#using-open-source-models) + - [Using Open Source Models with Ollama](#using-open-source-models-with-ollama) + - [Choosing the Right Models](#choosing-the-right-models) + - [Using Open Source Models without Ollama](#using-open-source-models-without-ollama) - [Guides](#guides) - [Creating Your First Hierarchical Team](#creating-your-first-hierarchical-team) - [Equipping Your Team Member with Skills](#equipping-your-team-member-with-skills) @@ -61,6 +65,7 @@ and many many more! - **Tool Calling**: Enable your agents to utilize external tools and APIs. - **Retrieval Augmented Generation**: Enable your agents to reason with your internal knowledge base. - **Human-In-The-Loop**: Enable human approval before tool calling. +- **Open Source Models**: Use open-source LLM models such as llama, Gemma and Phi. - **Easy Deployment**: Deploy Tribe effortlessly using Docker. - **Multi-Tenancy**: Manage and support multiple users and teams. @@ -198,6 +203,31 @@ DENSE_EMBEDDING_MODEL=BAAI/bge-small-en-v1.5 # Change this > [!WARNING] > If your existing and new embedding models have different vector dimensions, you may need to recreate your Qdrant collection. You can delete the collection through the Qdrant Dashboard at [http://qdrant.localhost/dashboard](http://qdrant.localhost/dashboard). Therefore, it is better to plan ahead which embedding model is most suitable for your workflows. +### Using Open Source Models + +Open source models are becoming cheaper and easier to run, and some even match the performance of closed models. You might prefer using them for their privacy and cost benefits. If you are running Tribe locally and want to use open source models, I would recommend Ollama for its ease of use. + +#### Using Open Source Models with Ollama +1. **Install Ollama:** First, set up Ollama on your device. You can find the instructions in [Ollama's repo](https://github.com/ollama/ollama). +2. **Download Models:** Download your preferred models from Ollama +3. **Configure your agents:** + - Update the agent's provider to `ollama`. 
   - Paste the downloaded model's name (e.g., `llama3.1:8b`) into the model input field.
   - By default, Tribe connects to Ollama at `http://host.docker.internal:11434`, which maps to `http://localhost:11434` on the host machine. This allows Tribe, running inside Docker, to reach a default Ollama install. If your Ollama instance listens on a different host or port, specify it in the 'Base URL' input field.

#### Choosing the Right Models
There are hundreds of open source models in [Ollama's library](https://ollama.com/library) suited to different tasks. Here's how to choose the right one for your use case:
- **Tool Calling Models:** If you plan to equip your agents with skills, pick a model that supports tool calling, such as `Llama3.1`, `Mistral Nemo`, `Firefunction V2`, or `Command R+`.
- **Creative, Reasoning and Other Tasks:** You have more flexibility. You can stick with a tool-calling-capable model, or consider models like `gemma2` or `phi3`.

#### Using Open Source Models without Ollama

If you're not planning to use Ollama, you can still use open source models served through any inference server that is compatible with the [OpenAI chat completions API](https://platform.openai.com/docs/api-reference/introduction).

Steps:
1. **Edit Your Agent:** Select 'OpenAI' as your model provider.
2. **Specify Endpoint:** Under 'Base URL', enter your model's inference endpoint, as shown in the sketch below.
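For illustration, this is roughly what that configuration amounts to under the hood. The endpoint, model name, and API key below are placeholders for whatever your own server (e.g., vLLM or llama.cpp's server) exposes:

```python
from langchain_openai import ChatOpenAI

# Hypothetical local endpoint; any server that speaks the OpenAI
# chat completions API will work here.
llm = ChatOpenAI(
    model="llama3.1:8b",                  # the model your server hosts
    base_url="http://localhost:8000/v1",  # your server's inference endpoint
    api_key="not-needed",                 # many local servers accept any key
    temperature=0.7,
)
print(llm.invoke("Hello!").content)
```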
### Guides

#### Creating Your First Hierarchical Team
diff --git a/backend/app/alembic/versions/0a354b5c6f6c_create_writes_table.py b/backend/app/alembic/versions/0a354b5c6f6c_create_writes_table.py
index adbc8639..4708467f 100644
--- a/backend/app/alembic/versions/0a354b5c6f6c_create_writes_table.py
+++ b/backend/app/alembic/versions/0a354b5c6f6c_create_writes_table.py
@@ -8,6 +8,7 @@
 from alembic import op
 import sqlalchemy as sa
 import sqlmodel.sql.sqltypes
+from sqlalchemy.dialects.postgresql import UUID


 # revision identifiers, used by Alembic.
@@ -20,9 +21,9 @@ def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('writes',
-    sa.Column('thread_id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
-    sa.Column('thread_ts', sqlmodel.sql.sqltypes.GUID(), nullable=False),
-    sa.Column('task_id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
+    sa.Column('thread_id', UUID(as_uuid=True), nullable=False),
+    sa.Column('thread_ts', UUID(as_uuid=True), nullable=False),
+    sa.Column('task_id', UUID(as_uuid=True), nullable=False),
     sa.Column('idx', sa.Integer(), nullable=False),
     sa.Column('channel', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
     sa.Column('value', sa.LargeBinary(), nullable=False),
diff --git a/backend/app/alembic/versions/38a9c73bfce2_add_base_url_col_to_members_table.py b/backend/app/alembic/versions/38a9c73bfce2_add_base_url_col_to_members_table.py
new file mode 100644
index 00000000..e1487b27
--- /dev/null
+++ b/backend/app/alembic/versions/38a9c73bfce2_add_base_url_col_to_members_table.py
@@ -0,0 +1,29 @@
+"""add base_url col to members table
+
+Revision ID: 38a9c73bfce2
+Revises: 6e7c33ddf30f
+Create Date: 2024-07-29 15:18:15.979804
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+
+
+# revision identifiers, used by Alembic.
+revision = '38a9c73bfce2'
+down_revision = '6e7c33ddf30f'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('member', sa.Column('base_url', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('member', 'base_url')
+    # ### end Alembic commands ###
\ No newline at end of file
diff --git a/backend/app/alembic/versions/3a8a5f819c5f_add_thread_table.py b/backend/app/alembic/versions/3a8a5f819c5f_add_thread_table.py
index 9794c5aa..597493bc 100644
--- a/backend/app/alembic/versions/3a8a5f819c5f_add_thread_table.py
+++ b/backend/app/alembic/versions/3a8a5f819c5f_add_thread_table.py
@@ -8,6 +8,8 @@
 from alembic import op
 import sqlalchemy as sa
 import sqlmodel.sql.sqltypes
+from sqlalchemy.dialects.postgresql import UUID
+


 # revision identifiers, used by Alembic.
@@ -21,7 +23,7 @@ def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('thread',
     sa.Column('query', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
-    sa.Column('id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
+    sa.Column('id', UUID(as_uuid=True), nullable=False),
     sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
     sa.Column('team_id', sa.Integer(), nullable=False),
     sa.ForeignKeyConstraint(['team_id'], ['team.id'], ),
diff --git a/backend/app/alembic/versions/6fa42be09dd2_add_checkpoints_table.py b/backend/app/alembic/versions/6fa42be09dd2_add_checkpoints_table.py
index 9c5d8a45..04811e5a 100644
--- a/backend/app/alembic/versions/6fa42be09dd2_add_checkpoints_table.py
+++ b/backend/app/alembic/versions/6fa42be09dd2_add_checkpoints_table.py
@@ -7,7 +7,7 @@
 """
 from alembic import op
 import sqlalchemy as sa
-import sqlmodel.sql.sqltypes
+from sqlalchemy.dialects.postgresql import UUID


 # revision identifiers, used by Alembic.
@@ -20,9 +20,9 @@ def upgrade():
    # ### commands auto generated by Alembic - please adjust!
### op.create_table('checkpoints', - sa.Column('thread_id', sqlmodel.sql.sqltypes.GUID(), nullable=False), - sa.Column('thread_ts', sqlmodel.sql.sqltypes.GUID(), nullable=False), - sa.Column('parent_ts', sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column('thread_id', UUID(as_uuid=True), nullable=False), + sa.Column('thread_ts', UUID(as_uuid=True), nullable=False), + sa.Column('parent_ts', UUID(as_uuid=True), nullable=True), sa.Column('checkpoint', sa.LargeBinary(), nullable=False), sa.Column('metadata', sa.LargeBinary(), nullable=False), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), diff --git a/backend/app/core/graph/build.py b/backend/app/core/graph/build.py index bcf53e85..e8e18a95 100644 --- a/backend/app/core/graph/build.py +++ b/backend/app/core/graph/build.py @@ -88,6 +88,7 @@ def convert_hierarchical_team_to_dict( teams[leader_name] = GraphTeam( name=leader_name, model=member.model, + base_url=member.base_url, role=member.role, backstory=member.backstory or "", members={}, @@ -126,6 +127,7 @@ def convert_hierarchical_team_to_dict( tools=tools, provider=member.provider, model=member.model, + base_url=member.base_url, temperature=member.temperature, interrupt=member.interrupt, ) @@ -136,6 +138,7 @@ def convert_hierarchical_team_to_dict( role=member.role, provider=member.provider, model=member.model, + base_url=member.base_url, temperature=member.temperature, ) for nei_id in out_counts[member_id]: @@ -196,6 +199,7 @@ def convert_sequential_team_to_dict(members: list[Member]) -> Mapping[str, Graph tools=tools, provider=memberModel.provider, model=memberModel.model, + base_url=memberModel.base_url, temperature=memberModel.temperature, interrupt=memberModel.interrupt, ) @@ -288,6 +292,7 @@ def create_hierarchical_graph( LeaderNode( teams[leader_name].provider, teams[leader_name].model, + teams[leader_name].base_url, teams[leader_name].temperature, ).delegate # type: ignore[arg-type] ), @@ -298,6 +303,7 @@ def create_hierarchical_graph( SummariserNode( teams[leader_name].provider, teams[leader_name].model, + teams[leader_name].base_url, teams[leader_name].temperature, ).summarise # type: ignore[arg-type] ), @@ -312,6 +318,7 @@ def create_hierarchical_graph( WorkerNode( member.provider, member.model, + member.base_url, member.temperature, ).work # type: ignore[arg-type] ), @@ -385,7 +392,10 @@ def create_sequential_graph( member.name, RunnableLambda( SequentialWorkerNode( - member.provider, member.model, member.temperature + member.provider, + member.model, + member.base_url, + member.temperature, ).work # type: ignore[arg-type] ), ) @@ -489,6 +499,7 @@ async def generator( members=member_dict, # type: ignore[arg-type] provider=first_member.provider, model=first_member.model, + base_url=first_member.base_url, temperature=first_member.temperature, ), "messages": [], diff --git a/backend/app/core/graph/members.py b/backend/app/core/graph/members.py index c5143d1d..0334973e 100644 --- a/backend/app/core/graph/members.py +++ b/backend/app/core/graph/members.py @@ -12,6 +12,7 @@ RunnableSerializable, ) from langchain_core.tools import BaseTool +from langchain_ollama import ChatOllama from langchain_openai import ChatOpenAI from langgraph.graph import add_messages from pydantic import BaseModel, Field @@ -58,6 +59,10 @@ class GraphPerson(BaseModel): role: str = Field(description="Role of the person") provider: str = Field(description="The provider for the llm model") model: str = Field(description="The llm model to use for this 
person") + base_url: str | None = Field( + default=None, + description="Use a proxy to serve llm model", + ) temperature: float = Field(description="The temperature of the llm model") backstory: str = Field( description="Description of the person's experience, motives and concerns." @@ -94,6 +99,9 @@ class GraphTeam(BaseModel): ) provider: str = Field(description="The provider of the team leader's llm model") model: str = Field(description="The llm model to use for this team leader") + base_url: str | None = Field( + default=None, description="Use a proxy to serve llm model" + ) temperature: float = Field( description="The temperature of the team leader's llm model" ) @@ -146,10 +154,28 @@ class ReturnTeamState(TypedDict): class BaseNode: - def __init__(self, provider: str, model: str, temperature: float): - self.model = init_chat_model( - model, model_provider=provider, temperature=temperature, streaming=True - ) + def __init__( + self, provider: str, model: str, base_url: str | None, temperature: float + ): + # If using proxy, then we need to pass base url + # TODO: Include ollama here once langchain-ollama bug is fixed + if provider in ["openai"] and base_url: + self.model = init_chat_model( + model, + model_provider=provider, + temperature=temperature, + base_url=base_url, + ) + elif provider == "ollama": + self.model = ChatOllama( + model=model, + temperature=temperature, + base_url=base_url if base_url else "http://host.docker.internal:11434", + ) + else: + self.model = init_chat_model( + model, model_provider=provider, temperature=0, streaming=True + ) self.final_answer_model = init_chat_model( model, model_provider=provider, temperature=0, streaming=True ) diff --git a/backend/app/models.py b/backend/app/models.py index 5a66f04f..cc0a5e60 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -256,6 +256,7 @@ class MemberBase(SQLModel): model: str = "gpt-4o-mini" temperature: float = 0.7 interrupt: bool = False + base_url: str | None = None class MemberCreate(MemberBase): diff --git a/backend/poetry.lock b/backend/poetry.lock index 0c663224..f25c11bc 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -1816,13 +1816,13 @@ test = ["Cython (>=0.29.24,<0.30.0)"] [[package]] name = "httpx" -version = "0.25.2" +version = "0.27.0" description = "The next generation HTTP client." 
optional = false python-versions = ">=3.8" files = [ - {file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"}, - {file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"}, + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, ] [package.dependencies] @@ -2102,19 +2102,19 @@ zookeeper = ["kazoo (>=2.8.0)"] [[package]] name = "langchain" -version = "0.2.11" +version = "0.2.12" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain-0.2.11-py3-none-any.whl", hash = "sha256:5a7a8b4918f3d3bebce9b4f23b92d050699e6f7fb97591e8941177cf07a260a2"}, - {file = "langchain-0.2.11.tar.gz", hash = "sha256:d7a9e4165f02dca0bd78addbc2319d5b9286b5d37c51d784124102b57e9fd297"}, + {file = "langchain-0.2.12-py3-none-any.whl", hash = "sha256:565d2f5df1c06815d1c684400218ec4ae5e1027887aad343226fad846c54e726"}, + {file = "langchain-0.2.12.tar.gz", hash = "sha256:fe7bd409c133017446fec54c38a5e7cb14f74e020090d7b5065374badf71e6d1"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""} -langchain-core = ">=0.2.23,<0.3.0" +langchain-core = ">=0.2.27,<0.3.0" langchain-text-splitters = ">=0.2.0,<0.3.0" langsmith = ">=0.1.17,<0.2.0" numpy = [ @@ -2186,13 +2186,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-core" -version = "0.2.24" +version = "0.2.28" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_core-0.2.24-py3-none-any.whl", hash = "sha256:9444fc082d21ef075d925590a684a73fe1f9688a3d90087580ec929751be55e7"}, - {file = "langchain_core-0.2.24.tar.gz", hash = "sha256:f2e3fa200b124e8c45d270da9bf836bed9c09532612c96ff3225e59b9a232f5a"}, + {file = "langchain_core-0.2.28-py3-none-any.whl", hash = "sha256:0728761d02ce696a1c6a57cfad18b874cf6c9566ba86120e2f542e442cb77a06"}, + {file = "langchain_core-0.2.28.tar.gz", hash = "sha256:589f907fcb1f15acea55ea3f451a37faaa61c2e68b3d39d436cf73ca3dd23ef5"}, ] [package.dependencies] @@ -2205,6 +2205,7 @@ pydantic = [ ] PyYAML = ">=5.3" tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" +typing-extensions = ">=4.7" [[package]] name = "langchain-google-genai" @@ -2224,6 +2225,21 @@ langchain-core = ">=0.2.0,<0.3" [package.extras] images = ["pillow (>=10.1.0,<11.0.0)"] +[[package]] +name = "langchain-ollama" +version = "0.1.1" +description = "An integration package connecting Ollama and LangChain" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_ollama-0.1.1-py3-none-any.whl", hash = "sha256:179b6f21e01fc72ebc034ec725f8c5dcef4a81709919278e6fa4f43605df5d82"}, + {file = "langchain_ollama-0.1.1.tar.gz", hash = "sha256:91b3b6cfcc90890c683995520d84210ebd2cee8c0f2cd0a5ffde9f1ffbee2f94"}, +] + +[package.dependencies] +langchain-core = ">=0.2.20,<0.3.0" +ollama = ">=0.3.0,<1" + [[package]] name = "langchain-openai" version = "0.1.17" @@ -2286,30 +2302,6 @@ files = [ [package.dependencies] langchain-core = ">=0.2.19,<0.3" -[[package]] -name = "langserve" -version = "0.0.51" -description = "" -optional = false -python-versions = ">=3.8.1,<4.0.0" -files = [ - {file = 
"langserve-0.0.51-py3-none-any.whl", hash = "sha256:e735eef2b6fde7e1514f4be8234b9f0727283e639822ca9c25e8ccc2d24e8492"}, - {file = "langserve-0.0.51.tar.gz", hash = "sha256:036c0104c512bcc2c2406ae089ef9e7e718c32c39ebf6dcb2212f168c7d09816"}, -] - -[package.dependencies] -fastapi = {version = ">=0.90.1,<1", optional = true, markers = "extra == \"server\" or extra == \"all\""} -httpx = ">=0.23.0" -langchain = ">=0.0.333" -orjson = ">=2" -pydantic = ">=1" -sse-starlette = {version = ">=1.3.0,<2.0.0", optional = true, markers = "extra == \"server\" or extra == \"all\""} - -[package.extras] -all = ["fastapi (>=0.90.1,<1)", "httpx-sse (>=0.3.1)", "sse-starlette (>=1.3.0,<2.0.0)"] -client = ["httpx-sse (>=0.3.1)"] -server = ["fastapi (>=0.90.1,<1)", "sse-starlette (>=1.3.0,<2.0.0)"] - [[package]] name = "langsmith" version = "0.1.82" @@ -2961,6 +2953,20 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] +[[package]] +name = "ollama" +version = "0.3.0" +description = "The official Python client for Ollama." +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "ollama-0.3.0-py3-none-any.whl", hash = "sha256:cd7010c4e2a37d7f08f36cd35c4592b14f1ec0d1bf3df10342cd47963d81ad7a"}, + {file = "ollama-0.3.0.tar.gz", hash = "sha256:6ff493a2945ba76cdd6b7912a1cd79a45cfd9ba9120d14adeb63b2b5a7f353da"}, +] + +[package.dependencies] +httpx = ">=0.27.0,<0.28.0" + [[package]] name = "onnx" version = "1.16.1" @@ -4471,35 +4477,18 @@ sqlcipher = ["sqlcipher3_binary"] [[package]] name = "sqlmodel" -version = "0.0.16" +version = "0.0.21" description = "SQLModel, SQL databases in Python, designed for simplicity, compatibility, and robustness." optional = false -python-versions = ">=3.7,<4.0" +python-versions = ">=3.7" files = [ - {file = "sqlmodel-0.0.16-py3-none-any.whl", hash = "sha256:b972f5d319580d6c37ecc417881f6ec4d1ad3ed3583d0ac0ed43234a28bf605a"}, - {file = "sqlmodel-0.0.16.tar.gz", hash = "sha256:966656f18a8e9a2d159eb215b07fb0cf5222acfae3362707ca611848a8a06bd1"}, + {file = "sqlmodel-0.0.21-py3-none-any.whl", hash = "sha256:eca104afe8a643f0764076b29f02e51d19d6b35c458f4c119942960362a4b52a"}, + {file = "sqlmodel-0.0.21.tar.gz", hash = "sha256:b2034c23d930f66d2091b17a4280a9c23a7ea540a71e7fcf9c746d262f06f74a"}, ] [package.dependencies] pydantic = ">=1.10.13,<3.0.0" -SQLAlchemy = ">=2.0.0,<2.1.0" - -[[package]] -name = "sse-starlette" -version = "1.8.2" -description = "SSE plugin for Starlette" -optional = false -python-versions = ">=3.8" -files = [ - {file = "sse_starlette-1.8.2-py3-none-any.whl", hash = "sha256:70cc7ef5aca4abe8a25dec1284cce4fe644dd7bf0c406d3e852e516092b7f849"}, - {file = "sse_starlette-1.8.2.tar.gz", hash = "sha256:e0f9b8dec41adc092a0a6e0694334bd3cfd3084c44c497a6ebc1fb4bdd919acd"}, -] - -[package.dependencies] -anyio = "*" -fastapi = "*" -starlette = "*" -uvicorn = "*" +SQLAlchemy = ">=2.0.14,<2.1.0" [[package]] name = "starlette" @@ -5325,4 +5314,4 @@ repair = ["scipy (>=1.6.3)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "0f6ffd87c9bfece9f79bf90483af19cb7796e389939419c35dca49347388a10b" +content-hash = "7335e20fadc21f74bfb3d9fa6462eba11fcbcc33fe73f2d176c58de85ea5f8ab" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 147be0d1..788a1e5d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -18,18 +18,17 @@ emails = "^0.6" gunicorn = "^22.0.0" jinja2 = "^3.1.4" alembic = "^1.12.1" -httpx = "^0.25.1" +httpx = "0.27.0" psycopg = {extras = 
["binary"], version = "^3.1.13"} -sqlmodel = "^0.0.16" +sqlmodel = "0.0.21" # Pin bcrypt until passlib supports the latest bcrypt = "4.0.1" pydantic-settings = "^2.2.1" sentry-sdk = {extras = ["fastapi"], version = "^2.8.0"} langgraph = "0.1.9" -langserve = {extras = ["server"], version = "^0.0.51"} langchain-openai = "0.1.17" grandalf = "^0.8" -langchain = "0.2.11" +langchain = "0.2.12" langchain-community = "0.2.9" duckduckgo-search = "6.1.0" wikipedia = "^1.4.0" @@ -38,7 +37,7 @@ langchain-cohere = "^0.1.4" langchain-google-genai = "^1.0.2" google-search-results = "^2.4.2" yfinance = "^0.2.38" -langchain-core = "0.2.24" +langchain-core = "0.2.28" pyjwt = "^2.8.0" psycopg2 = "^2.9.9" asyncpg = "^0.29.0" @@ -53,6 +52,7 @@ redis = "^5.0.7" celery-stubs = "^0.1.3" pymupdf = "^1.24.7" psycopg-pool = "^3.2.2" +langchain-ollama = "0.1.1" [tool.poetry.group.dev.dependencies] pytest = "^7.4.3" diff --git a/frontend/src/client/models/MemberCreate.ts b/frontend/src/client/models/MemberCreate.ts index 191a3a59..0dd2053d 100644 --- a/frontend/src/client/models/MemberCreate.ts +++ b/frontend/src/client/models/MemberCreate.ts @@ -16,5 +16,6 @@ export type MemberCreate = { model?: string; temperature?: number; interrupt?: boolean; + base_url?: (string | null); }; diff --git a/frontend/src/client/models/MemberOut.ts b/frontend/src/client/models/MemberOut.ts index d335e346..7b3654a3 100644 --- a/frontend/src/client/models/MemberOut.ts +++ b/frontend/src/client/models/MemberOut.ts @@ -19,6 +19,7 @@ export type MemberOut = { model?: string; temperature?: number; interrupt?: boolean; + base_url?: (string | null); id: number; belongs_to: number; skills: Array; diff --git a/frontend/src/client/models/MemberUpdate.ts b/frontend/src/client/models/MemberUpdate.ts index 9e217b0a..a9511b3b 100644 --- a/frontend/src/client/models/MemberUpdate.ts +++ b/frontend/src/client/models/MemberUpdate.ts @@ -19,6 +19,7 @@ export type MemberUpdate = { model?: (string | null); temperature?: (number | null); interrupt?: (boolean | null); + base_url?: (string | null); belongs_to?: (number | null); skills?: (Array | null); uploads?: (Array | null); diff --git a/frontend/src/client/schemas/$MemberCreate.ts b/frontend/src/client/schemas/$MemberCreate.ts index 934d0ed8..9c1b7d9f 100644 --- a/frontend/src/client/schemas/$MemberCreate.ts +++ b/frontend/src/client/schemas/$MemberCreate.ts @@ -61,5 +61,13 @@ export const $MemberCreate = { interrupt: { type: 'boolean', }, + base_url: { + type: 'any-of', + contains: [{ + type: 'string', + }, { + type: 'null', + }], + }, }, } as const; diff --git a/frontend/src/client/schemas/$MemberOut.ts b/frontend/src/client/schemas/$MemberOut.ts index 453a1d77..05980b11 100644 --- a/frontend/src/client/schemas/$MemberOut.ts +++ b/frontend/src/client/schemas/$MemberOut.ts @@ -62,6 +62,14 @@ export const $MemberOut = { interrupt: { type: 'boolean', }, + base_url: { + type: 'any-of', + contains: [{ + type: 'string', + }, { + type: 'null', + }], + }, id: { type: 'number', isRequired: true, diff --git a/frontend/src/client/schemas/$MemberUpdate.ts b/frontend/src/client/schemas/$MemberUpdate.ts index c8bbc7a2..30e48ea8 100644 --- a/frontend/src/client/schemas/$MemberUpdate.ts +++ b/frontend/src/client/schemas/$MemberUpdate.ts @@ -101,6 +101,14 @@ export const $MemberUpdate = { type: 'null', }], }, + base_url: { + type: 'any-of', + contains: [{ + type: 'string', + }, { + type: 'null', + }], + }, belongs_to: { type: 'any-of', contains: [{ diff --git a/frontend/src/components/Members/EditMember.tsx 
b/frontend/src/components/Members/EditMember.tsx index b42f6a1f..1aa958a9 100644 --- a/frontend/src/components/Members/EditMember.tsx +++ b/frontend/src/components/Members/EditMember.tsx @@ -3,6 +3,7 @@ import { Checkbox, FormControl, FormErrorMessage, + FormHelperText, FormLabel, Input, Modal, @@ -32,7 +33,12 @@ import { UploadsService, } from "../../client" import { type SubmitHandler, useForm, Controller } from "react-hook-form" -import { Select as MultiSelect, chakraComponents } from "chakra-react-select" +import { + Select as MultiSelect, + chakraComponents, + CreatableSelect, + type OptionBase, +} from "chakra-react-select" import { useState } from "react" interface EditMemberProps { @@ -42,6 +48,11 @@ interface EditMemberProps { onClose: () => void } +interface ModelOption extends OptionBase { + label: string + value: string +} + const customSelectOption = { Option: (props: any) => ( @@ -54,10 +65,11 @@ const customSelectOption = { const AVAILABLE_MODELS = { openai: ["gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"], anthropic: [ - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", "claude-3-haiku-20240307", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", ], + ollama: ["llama3.1"], } type ModelProvider = keyof typeof AVAILABLE_MODELS @@ -98,6 +110,7 @@ export function EditMember({ reset, control, watch, + setValue, formState: { isSubmitting, errors, isDirty, isValid }, } = useForm({ mode: "onBlur", @@ -151,7 +164,6 @@ export function EditMember({ // Watch the type field to determine whether to disable multiselect const memberType = watch("type") - const selectedProvider = watch("provider") as ModelProvider const skillOptions = skills ? skills.data.map((skill) => ({ @@ -169,6 +181,14 @@ export function EditMember({ })) : [] + const modelProvider = watch("provider") as ModelProvider + const modelOptions: ModelOption[] = AVAILABLE_MODELS[modelProvider].map( + (model) => ({ + label: model, + value: model, + }), + ) + return ( @@ -302,11 +322,18 @@ export function EditMember({ ) : null} - + Provider - - Model - - + { + return ( + + Model + onChange(newValue?.value)} + onBlur={onBlur} + value={{ value: value, label: value }} + options={modelOptions} + useBasicStyles + /> + + If a model is not listed, you can type it in. + + + ) + }} + /> + {(modelProvider === "openai" || modelProvider === "ollama") && ( + + Proxy Provider + + {modelProvider === "ollama" && ( + + Default url: http://host.docker.internal:11434 + + )} + + )}
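
To try the new code path end to end, here is a minimal sketch (not part of the patch) that mirrors what `BaseNode` now constructs when `provider == "ollama"`. It assumes `ollama pull llama3.1` has been run and Ollama is listening on its default port; outside Docker, use `localhost` in place of `host.docker.internal`:

```python
from langchain_ollama import ChatOllama

# Mirrors the new ollama branch in BaseNode: when no base_url is set,
# Tribe falls back to http://host.docker.internal:11434 from inside Docker.
llm = ChatOllama(
    model="llama3.1",
    temperature=0.7,
    base_url="http://localhost:11434",
)
print(llm.invoke("Reply with one word: ready?").content)
```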