Skip to content

Commit

Permalink
WIP websocket callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Dec 9, 2024
1 parent 791f675 commit 9343301
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 67 deletions.
91 changes: 28 additions & 63 deletions chat/src/agent/agent_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Optional, Union, List
from uuid import UUID
# from websocket import Websocket
from websocket import Websocket

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
Expand All @@ -9,73 +9,38 @@


class AgentHandler(BaseCallbackHandler):
def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
pass
def on_llm_new_token(self, token: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
socket: Websocket = metadata.get("socket", None)

def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
pass

def remove_handler(self, handler: BaseCallbackHandler) -> None:
pass

def set_handlers(self, handlers: List[BaseCallbackHandler], inherit: bool = True) -> None:
pass

def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
pass

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
pass
if socket is None:
raise ValueError("Socket not defined in agent handler via metadata")

socket.send("test")

def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
pass
# def on_tool_start(self, serialized: Dict[str, Any], input_str: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
# print(f"on_tool_start: {serialized, input_str}")
# # socket: Websocket = metadata.get("socket", None)

def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
pass
# # if socket is None:
# # raise ValueError("Socket not defined in agent handler via metadata")

# # socket.send(f"🎉 I'm working on it")

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
pass
# def on_tool_end(self, output: str, **kwargs: Any) -> Any:
# print(f"on_tool_end: {output}")
# pass

def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
pass
# def on_text(self, text: str, **kwargs: Any) -> Any:
# print(f"on_text: {text}")
# pass

# def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
# print(f"on_agent_action: {action}")
# pass

def on_tool_start(self, serialized: Dict[str, Any], input_str: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
print(f"on_tool_start: {serialized, input_str}")
print(f"on_tool_start kawrgs: {kwargs}")
# socket: Websocket = metadata.get("socket", None)

# if socket is None:
# raise ValueError("Socket not defined in agent handler via metadata")
socket: Websocket = metadata.get("socket", None)
if socket is None:
raise ValueError("Socket not defined in agent handler via metadata")

# socket.send(f"🎉 I'm working on it")

def on_tool_end(self, output: str, **kwargs: Any) -> Any:
print(f"on_tool_end: {output}")
print(f"on_tool_end kwargs: {kwargs}")
pass

def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
pass

def on_text(self, text: str, **kwargs: Any) -> Any:
pass

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
print(f"on_agent_action: {action}")
print(f"on_agent_action kwargs: {kwargs}")
pass

def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
pass


"""
on_tool_start (A tool is starting): (
{'name': 'search', 'description': "Perform a semantic search of Northwestern University Library digital collections. When answering a search query, ground your answer in the context of the results with references to the document's metadata."},
"{'query': 'World War II Posters visual themes'}"
)
"""
socket.send(f"🎉 I'm working on it")
2 changes: 1 addition & 1 deletion chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def call_model(state: MessagesState):
checkpointer = MemorySaver()

# Compile the graph
search_agent = workflow.compile(checkpointer=checkpointer, debug=True)
search_agent = workflow.compile(checkpointer=checkpointer, debug=False)
6 changes: 3 additions & 3 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import os
from datetime import datetime
from event_config import EventConfig
from honeybadger import honeybadger
# from honeybadger import honeybadger
from agent.search_agent import search_agent
from langchain_core.messages import HumanMessage
from agent.agent_handler import AgentHandler

honeybadger.configure()
logging.getLogger("honeybadger").addHandler(logging.StreamHandler())
# honeybadger.configure()
# logging.getLogger("honeybadger").addHandler(logging.StreamHandler())

RESPONSE_TYPES = {
"base": ["answer", "ref"],
Expand Down
1 change: 1 addition & 0 deletions chat/src/helpers/hybrid_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k:
result = {
"size": kwargs.get("size", 20),
"_source": {
"include": ["title", "description", "collection.title", "id", "collection.id"],
"exclude": ["embedding"]
},
"query": {
Expand Down
6 changes: 6 additions & 0 deletions chat/src/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,9 @@ def send(self, data):
else:
self.client.post_to_connection(Data=data_as_bytes, ConnectionId=self.connection_id)
return data

def __str__(self):
return f"Websocket({self.connection_id}, {self.ref})"

def __repr__(self):
return str(self)

0 comments on commit 9343301

Please sign in to comment.