Commit
Enable ollama models as agents (#98)
StreetLamb authored Aug 3, 2024
1 parent f299d0c commit cbb0045
Showing 17 changed files with 266 additions and 92 deletions.
30 changes: 30 additions & 0 deletions README.md
@@ -30,6 +30,10 @@
- [Writing a Custom Skill using LangChain](#writing-a-custom-skill-using-langchain)
- [Retrieval Augmented Generation (RAG)](#retrieval-augmented-generation-rag)
- [Customising embedding models](#customising-embedding-models)
- [Using Open Source Models](#using-open-source-models)
- [Using Open Source Models with Ollama](#using-open-source-models-with-ollama)
- [Choosing the Right Models](#choosing-the-right-models)
- [Using Open Source Models without Ollama](#using-open-source-models-without-ollama)
- [Guides](#guides)
- [Creating Your First Hierarchical Team](#creating-your-first-hierarchical-team)
- [Equipping Your Team Member with Skills](#equipping-your-team-member-with-skills)
@@ -61,6 +65,7 @@ and many many more!
- **Tool Calling**: Enable your agents to utilize external tools and APIs.
- **Retrieval Augmented Generation**: Enable your agents to reason with your internal knowledge base.
- **Human-In-The-Loop**: Enable human approval before tool calling.
- **Open Source Models**: Use open-source models such as Llama, Gemma, and Phi.
- **Easy Deployment**: Deploy Tribe effortlessly using Docker.
- **Multi-Tenancy**: Manage and support multiple users and teams.

@@ -198,6 +203,31 @@ DENSE_EMBEDDING_MODEL=BAAI/bge-small-en-v1.5 # Change this
> [!WARNING]
> If your existing and new embedding models have different vector dimensions, you may need to recreate your Qdrant collection. You can delete the collection through the Qdrant Dashboard at [http://qdrant.localhost/dashboard](http://qdrant.localhost/dashboard). Therefore, it is better to plan ahead which embedding model is most suitable for your workflows.
### Using Open Source Models

Open source models are becoming cheaper and easier to run, and some even match the performance of closed models. You might prefer using them for their privacy and cost benefits. If you are running Tribe locally and want to use open source models, I would recommend Ollama for its ease of use.

#### Using Open Source Models with Ollama
1. **Install Ollama:** First, set up Ollama on your device. You can find the instructions in [Ollama's repo](https://github.com/ollama/ollama).
2. **Download Models:** Download your preferred models from [Ollama's library](https://ollama.com/library), e.g. `ollama pull llama3.1:8b`.
3. **Configure your agents:**
- Update the agent's provider to `ollama`.
- Paste the downloaded model's name (e.g., `llama3.1:8b`) into the model input field.
   - By default, Tribe connects to Ollama at `http://host.docker.internal:11434`, which maps to `http://localhost:11434` on the host machine. This allows Tribe to reach the default Ollama host. If your setup uses a different host, specify it in the 'Base URL' input field.

#### Choosing the Right Models
There are hundreds of open source models in [Ollama's library](https://ollama.com/library) suitable for different tasks. Here’s how to choose the right one for your use case:
- **Tool Calling Models:** If you plan to equip agents with specific skills, use models that support tool calling, such as `Llama3.1`, `Mistral Nemo`, `Firefunction V2`, or `Command-R +`.
- **Creative, Reasoning, and Other Tasks:** You have more flexibility: stick with tool-calling-capable models, or consider models like `gemma2` or `phi3`.

#### Using Open Source Models without Ollama

If you’re not planning to use Ollama, you can still run open source models compatible with the [OpenAI chat completions API](https://platform.openai.com/docs/api-reference/introduction).

Steps:
1. **Edit Your Agent:** Select 'OpenAI' as your model provider.
2. **Specify Endpoint:** Under 'Base URL', specify the model’s inference endpoint.
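Any server that implements the chat completions API accepts the same request shape, which is why pointing the 'Base URL' at a different endpoint is enough. Below is a minimal stdlib-only sketch of such a request; the endpoint and model name are placeholder assumptions, not values Tribe ships with.

```python
import json

# Placeholder inference endpoint and model name for illustration only.
base_url = "http://localhost:8000/v1"
payload = {
    "model": "my-local-model",
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
}

# The chat completions route is appended to whatever Base URL you configure.
url = f"{base_url}/chat/completions"
body = json.dumps(payload)
print(url)  # → http://localhost:8000/v1/chat/completions
```

Servers such as vLLM or llama.cpp expose this same route, so only the Base URL changes between them.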

### Guides

#### Creating Your First Hierarchical Team
@@ -8,6 +8,7 @@
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects.postgresql import UUID


# revision identifiers, used by Alembic.
@@ -20,9 +21,9 @@
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('writes',
sa.Column('thread_id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('thread_ts', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('task_id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('thread_id', UUID(as_uuid=True), nullable=False),
sa.Column('thread_ts', UUID(as_uuid=True), nullable=False),
sa.Column('task_id', UUID(as_uuid=True), nullable=False),
sa.Column('idx', sa.Integer(), nullable=False),
sa.Column('channel', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('value', sa.LargeBinary(), nullable=False),
@@ -0,0 +1,29 @@
"""add base_url col to members table

Revision ID: 38a9c73bfce2
Revises: 6e7c33ddf30f
Create Date: 2024-07-29 15:18:15.979804

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '38a9c73bfce2'
down_revision = '6e7c33ddf30f'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('member', sa.Column('base_url', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('member', 'base_url')
# ### end Alembic commands ###
@@ -8,6 +8,8 @@
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects.postgresql import UUID



# revision identifiers, used by Alembic.
@@ -21,7 +23,7 @@ def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('thread',
sa.Column('query', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('id', UUID(as_uuid=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('team_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['team_id'], ['team.id'], ),
@@ -7,7 +7,7 @@
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects.postgresql import UUID


# revision identifiers, used by Alembic.
@@ -20,9 +20,9 @@
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('checkpoints',
sa.Column('thread_id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('thread_ts', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('parent_ts', sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column('thread_id', UUID(as_uuid=True), nullable=False),
sa.Column('thread_ts', UUID(as_uuid=True), nullable=False),
sa.Column('parent_ts', UUID(as_uuid=True), nullable=True),
sa.Column('checkpoint', sa.LargeBinary(), nullable=False),
sa.Column('metadata', sa.LargeBinary(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
13 changes: 12 additions & 1 deletion backend/app/core/graph/build.py
@@ -88,6 +88,7 @@ def convert_hierarchical_team_to_dict(
teams[leader_name] = GraphTeam(
name=leader_name,
model=member.model,
base_url=member.base_url,
role=member.role,
backstory=member.backstory or "",
members={},
@@ -126,6 +127,7 @@ def convert_hierarchical_team_to_dict(
tools=tools,
provider=member.provider,
model=member.model,
base_url=member.base_url,
temperature=member.temperature,
interrupt=member.interrupt,
)
@@ -136,6 +138,7 @@
role=member.role,
provider=member.provider,
model=member.model,
base_url=member.base_url,
temperature=member.temperature,
)
for nei_id in out_counts[member_id]:
@@ -196,6 +199,7 @@ def convert_sequential_team_to_dict(members: list[Member]) -> Mapping[str, Graph
tools=tools,
provider=memberModel.provider,
model=memberModel.model,
base_url=memberModel.base_url,
temperature=memberModel.temperature,
interrupt=memberModel.interrupt,
)
@@ -288,6 +292,7 @@ def create_hierarchical_graph(
LeaderNode(
teams[leader_name].provider,
teams[leader_name].model,
teams[leader_name].base_url,
teams[leader_name].temperature,
).delegate # type: ignore[arg-type]
),
@@ -298,6 +303,7 @@
SummariserNode(
teams[leader_name].provider,
teams[leader_name].model,
teams[leader_name].base_url,
teams[leader_name].temperature,
).summarise # type: ignore[arg-type]
),
@@ -312,6 +318,7 @@
WorkerNode(
member.provider,
member.model,
member.base_url,
member.temperature,
).work # type: ignore[arg-type]
),
@@ -385,7 +392,10 @@ def create_sequential_graph(
member.name,
RunnableLambda(
SequentialWorkerNode(
member.provider, member.model, member.temperature
member.provider,
member.model,
member.base_url,
member.temperature,
).work # type: ignore[arg-type]
),
)
@@ -489,6 +499,7 @@ async def generator(
members=member_dict, # type: ignore[arg-type]
provider=first_member.provider,
model=first_member.model,
base_url=first_member.base_url,
temperature=first_member.temperature,
),
"messages": [],
34 changes: 30 additions & 4 deletions backend/app/core/graph/members.py
@@ -12,6 +12,7 @@
RunnableSerializable,
)
from langchain_core.tools import BaseTool
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langgraph.graph import add_messages
from pydantic import BaseModel, Field
@@ -58,6 +59,10 @@ class GraphPerson(BaseModel):
role: str = Field(description="Role of the person")
provider: str = Field(description="The provider for the llm model")
model: str = Field(description="The llm model to use for this person")
base_url: str | None = Field(
default=None,
description="Use a proxy to serve llm model",
)
temperature: float = Field(description="The temperature of the llm model")
backstory: str = Field(
description="Description of the person's experience, motives and concerns."
@@ -94,6 +99,9 @@ class GraphTeam(BaseModel):
)
provider: str = Field(description="The provider of the team leader's llm model")
model: str = Field(description="The llm model to use for this team leader")
base_url: str | None = Field(
default=None, description="Use a proxy to serve llm model"
)
temperature: float = Field(
description="The temperature of the team leader's llm model"
)
@@ -146,10 +154,28 @@ class ReturnTeamState(TypedDict):


class BaseNode:
def __init__(self, provider: str, model: str, temperature: float):
self.model = init_chat_model(
model, model_provider=provider, temperature=temperature, streaming=True
)
def __init__(
self, provider: str, model: str, base_url: str | None, temperature: float
):
# If using a proxy, we need to pass the base url
# TODO: Include ollama here once the langchain-ollama bug is fixed
if provider in ["openai"] and base_url:
self.model = init_chat_model(
model,
model_provider=provider,
temperature=temperature,
base_url=base_url,
)
elif provider == "ollama":
self.model = ChatOllama(
model=model,
temperature=temperature,
base_url=base_url if base_url else "http://host.docker.internal:11434",
)
else:
self.model = init_chat_model(
model, model_provider=provider, temperature=temperature, streaming=True
)
self.final_answer_model = init_chat_model(
model, model_provider=provider, temperature=0, streaming=True
)
1 change: 1 addition & 0 deletions backend/app/models.py
@@ -256,6 +256,7 @@ class MemberBase(SQLModel):
model: str = "gpt-4o-mini"
temperature: float = 0.7
interrupt: bool = False
base_url: str | None = None


class MemberCreate(MemberBase):