Allow sequential agents to wait for human input (#129)
* Add human tool to allow agents to ask the human for input

* Update sequential graph to add ask-human tool logic. Update the generator function to send ask-human interrupts and process human replies.

* Update the ChatTeam component to request ask-human tool replies and submit them

* Upgrade langgraph version to v0.2.14

* Adjust the spacing within MessageBox

* Add ask-human tool logic in hierarchical graph

* Refactor EditMember to make configuration easier. Remove human tool for hierarchical workflows

* Enable tools and interrupt for root node of sequential workflows
StreetLamb authored Aug 25, 2024
1 parent 7dd0b3d commit 50ae80d
Showing 12 changed files with 269 additions and 118 deletions.
198 changes: 135 additions & 63 deletions backend/app/core/graph/build.py
@@ -13,13 +12,12 @@
)
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import ToolNode
from psycopg import AsyncConnection

from app.core.config import settings
@@ -241,28 +240,46 @@ def should_continue(state: TeamState) -> str:
"""Determine if graph should go to tool node or not. For tool calling agents."""
messages: list[AnyMessage] = state["messages"]
if messages and isinstance(messages[-1], AIMessage) and messages[-1].tool_calls:
return "call_tools"
# TODO: what if multiple tool_calls?
for tool_call in messages[-1].tool_calls:
if tool_call["name"] == "AskHuman":
return "call_human"
else:
return "call_tools"
else:
return "continue"


def create_tools_condition(
current_member_name: str,
next_member_name: str,
tools: list[GraphSkill | GraphUpload],
) -> dict[Hashable, str]:
"""Creates the mapping for conditional edges
The tool node must be in the format: '{current_member_name}_tools'
Args:
current_member_name (str): The name of the member that is calling the tool
next_member_name (str): The name of the next member after the current member processes the tool response. Can be END.
tools: List of tools that the agent has.
"""
mapping: dict[Hashable, str] = {
# Else continue to the next node
"continue": next_member_name,
}

for tool in tools:
if tool.name == "ask-human":
mapping["call_human"] = f"{current_member_name}_askHuman_tool"
else:
mapping["call_tools"] = f"{current_member_name}_tools"
return mapping


def ask_human(state: TeamState) -> None:
"""Dummy node for ask human tool"""
pass


def create_hierarchical_graph(
teams: dict[str, GraphTeam],
Expand All @@ -283,8 +300,7 @@ def create_hierarchical_graph(
dict: A dictionary representing the graph of teams.
"""
build = StateGraph(TeamState)
interrupt_member_names = [] # List to store members that require human intervention before tool calling
# Add the start and end node
build.add_node(
leader_name,
@@ -323,19 +339,28 @@
).work # type: ignore[arg-type]
),
)
if member.tools:
normal_tools: list[BaseTool] = []

for tool in member.tools:
if tool.name == "ask-human":
# Handling Ask-Human tool
interrupt_member_names.append(f"{name}_askHuman_tool")
build.add_node(f"{name}_askHuman_tool", ask_human)
build.add_edge(f"{name}_askHuman_tool", name)
else:
normal_tools.append(tool.tool)

if normal_tools:
# Add node for normal tools
build.add_node(f"{name}_tools", ToolNode(normal_tools))
build.add_edge(f"{name}_tools", name)

# Interrupt for normal tools only if member.interrupt is True
if member.interrupt:
interrupt_member_names.append(f"{name}_tools")

elif isinstance(member, GraphLeader):
# subgraphs do not require memory
subgraph = create_hierarchical_graph(
teams, leader_name=name, checkpointer=checkpointer
)
@@ -346,16 +371,17 @@
)
else:
continue

# If member has tools, we create conditional edge to either tool node or back to leader.
if isinstance(member, GraphMember) and member.tools:
build.add_conditional_edges(
name,
should_continue,
create_tools_condition(name, leader_name, member.tools),
)
else:
build.add_edge(name, leader_name)

conditional_mapping: dict[Hashable, str] = {v: v for v in members}
conditional_mapping["FINISH"] = "FinalAnswer"
build.add_conditional_edges(leader_name, router, conditional_mapping)
@@ -383,11 +409,11 @@ def create_sequential_graph(
Returns:
CompiledGraph: The compiled graph representing the sequential workflow.
"""
graph = StateGraph(TeamState)
interrupt_member_names = [] # List to store members that require human intervention before tool calling
members = list(team.values())

for i, member in enumerate(members):
graph.add_node(
member.name,
RunnableLambda(
@@ -399,40 +425,56 @@
).work # type: ignore[arg-type]
),
)

if member.tools:
normal_tools: list[BaseTool] = []

for tool in member.tools:
if tool.name == "ask-human":
# Handling Ask-Human tool
interrupt_member_names.append(f"{member.name}_askHuman_tool")
graph.add_node(f"{member.name}_askHuman_tool", ask_human)
graph.add_edge(f"{member.name}_askHuman_tool", member.name)
else:
normal_tools.append(tool.tool)

if normal_tools:
# Add node for normal tools
graph.add_node(f"{member.name}_tools", ToolNode(normal_tools))
graph.add_edge(f"{member.name}_tools", member.name)

# Interrupt for normal tools only if member.interrupt is True
if member.interrupt:
interrupt_member_names.append(f"{member.name}_tools")

if i > 0:
previous_member = members[i - 1]
if previous_member.tools:
graph.add_conditional_edges(
previous_member.name,
should_continue,
create_tools_condition(
previous_member.name, member.name, previous_member.tools
),
)
else:
graph.add_edge(previous_member.name, member.name)

# Handle the final member's tools
final_member = members[-1]
if final_member.tools:
graph.add_conditional_edges(
final_member.name,
should_continue,
create_tools_condition(final_member.name, END, final_member.tools),
)
else:
graph.add_edge(final_member.name, END)

graph.set_entry_point(members[0].name)
return graph.compile(
checkpointer=checkpointer,
interrupt_before=interrupt_member_names,
)


@@ -531,14 +573,34 @@ async def generator(
for tool_call in tool_calls
]
}
if interrupt.tool_message:
state["messages"].append(
HumanMessage(
content=interrupt.tool_message,
name="user",
id=str(uuid4()),
)
)
elif interrupt and interrupt.decision == InterruptDecision.REPLIED:
current_values = await root.aget_state(config)
messages = current_values.values["messages"]
if (
messages
and isinstance(messages[-1], AIMessage)
and interrupt.tool_message
):
tool_calls = messages[-1].tool_calls
state = {
"messages": [
ToolMessage(
tool_call_id=tool_call["id"],
content=interrupt.tool_message,
name="AskHuman",
)
for tool_call in tool_calls
if tool_call["name"] == "AskHuman"
]
}
async for event in root.astream_events(state, version="v2", config=config):
response = event_to_response(event)
if response:
@@ -550,13 +612,23 @@
message = snapshot.values["messages"][-1]
if not isinstance(message, AIMessage):
return

# Determine whether to return the default or the ask-human interrupt based on whether the AskHuman tool was called.
for tool_call in message.tool_calls:
if tool_call["name"] == "AskHuman":
response = ChatResponse(
type="interrupt",
name="human",
tool_calls=message.tool_calls,
id=str(uuid4()),
)
break
else:
response = ChatResponse(
type="interrupt",
name="interrupt",
tool_calls=message.tool_calls,
id=str(uuid4()),
)
formatted_output = f"data: {response.model_dump_json()}\n\n"
yield formatted_output
except Exception as e:
27 changes: 20 additions & 7 deletions backend/app/core/graph/checkpoint/utils.py
@@ -85,14 +85,27 @@ def convert_checkpoint_tuple_to_messages(

last_message = all_messages[-1]
if last_message.type == "ai" and last_message.tool_calls:
# Check if any tool in last message is asking for human input
for tool_call in last_message.tool_calls:
if tool_call["name"] == "AskHuman":
formatted_messages.append(
ChatResponse(
type="interrupt",
name="human",
tool_calls=last_message.tool_calls,
id=str(uuid4()),
)
)
break
else:
formatted_messages.append(
ChatResponse(
type="interrupt",
name="interrupt",
tool_calls=last_message.tool_calls,
id=str(uuid4()),
)
)
return formatted_messages


7 changes: 7 additions & 0 deletions backend/app/core/graph/skills/__init__.py
@@ -1,3 +1,4 @@
# mypy: disable-error-code="attr-defined, arg-type"
from langchain.pydantic_v1 import BaseModel
from langchain.tools import BaseTool
from langchain_community.tools import DuckDuckGoSearchRun, WikipediaQueryRun
@@ -6,6 +7,8 @@
WikipediaAPIWrapper,
)

from .human_tool import AskHuman

# from .calculator import multiply


@@ -26,6 +29,10 @@ class SkillInfo(BaseModel):
description="Get information from Yahoo Finance News.",
tool=YahooFinanceNewsTool(),
),
"ask-human": SkillInfo(
description=AskHuman.description,
tool=AskHuman,
),
# multiply.name: SkillInfo(
# description=multiply.description,
# tool=multiply,
11 changes: 11 additions & 0 deletions backend/app/core/graph/skills/human_tool.py
@@ -0,0 +1,11 @@
# Tool that lets an agent ask the human for additional input

from typing import Annotated

from langchain_core.tools import tool


@tool
def AskHuman(query: Annotated[str, "query to ask the human"]) -> None:
"""Ask the human a question to gather additional inputs"""
pass
3 changes: 2 additions & 1 deletion backend/app/models.py
@@ -131,11 +131,12 @@ class ChatMessage(BaseModel):
class InterruptDecision(Enum):
APPROVED = "approved"
REJECTED = "rejected"
REPLIED = "replied"


class Interrupt(BaseModel):
decision: InterruptDecision
tool_message: str | None = None


class TeamChat(BaseModel):
(Diffs for the remaining 7 changed files are not shown.)
