
Commit

Rewrote the tool use handler to be synchronous, I don't think the OpenAI Python SDK does a good job with async

dividor committed Jun 16, 2024
1 parent 9b2e0fe commit f96a23c
Showing 4 changed files with 119 additions and 144 deletions.
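
The commit message above describes the core change: the OpenAI Assistants run is now driven by the synchronous client with a synchronous AssistantEventHandler instead of AsyncAssistantEventHandler. Below is a minimal sketch of that pattern; the client construction, thread_id and assistant_id are placeholders for illustration, not code from this commit.

from typing_extensions import override
from openai import OpenAI, AssistantEventHandler

client = OpenAI()  # assumed: API key picked up from the environment

class MinimalHandler(AssistantEventHandler):
    @override
    def on_event(self, event):
        # Route events by type, much as the rewritten handler in app.py does
        if event.event == "thread.message.delta":
            for part in event.data.delta.content or []:
                if part.type == "text":
                    print(part.text.value, end="", flush=True)
        elif event.event == "thread.run.requires_action":
            # Tool calls would be executed here and their outputs submitted
            pass

with client.beta.threads.runs.stream(
    thread_id="thread_...",   # placeholder
    assistant_id="asst_...",  # placeholder
    event_handler=MinimalHandler(),
) as stream:
    stream.until_done()
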
22 changes: 21 additions & 1 deletion templates/ai_memory_judge_prompt.jinja2
@@ -8,7 +8,7 @@ You have specific memories you can match on, as well as generic skills that migh

For example, if the user wants to plot a scatter plot of food prices in Uganda, the top hit would be a match on a
memory like 'plot a scatter graph of prices in Uganda in the last 5 years'. If a specific memory match
like this doesn't exist, match on a generik skill that can be used with the user's input parameters, for example
like this doesn't exist, match on a generic skill that can be used with the user's input parameters, for example
'plot a scatter graph of prices in a country'.

Key points to consider:
@@ -17,6 +17,26 @@ Key points to consider:
- If your user has a very general question, e.g. 'What data do you have', but the possible match is more specific, e.g. 'what data do you have for region X', it is not a match
- 'Plot population pyramids' means the same as 'plot a pyramid plot'

Examples of Matches:

User intent: generate a line chart of Maize prices by year for 2013-2023 for Chad
Match: plot a line chart of commodity prices monthly relative change for Chad from 2008-01-01 using HDX data as an image
Reason: There is a general skill which can be used for the user's request

User intent: plot population pyramids for chad
Match: plot population pyramids for Chad
Reason: The user asked for exactly this, same plot, same country

Examples that are NOT matches:

User intent: plot a map of population by state in Haiti
Match: I would like a plot of population by state in Mali
Reason: The user asked for Haiti, not Mali

User intent: give me a plot of population in Chad
Match: plot population pyramids for Chad
Reason: The user asked for a general plot of population but didn't specify a plot type

The user asked for this:

{{ user_input }}
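
For context, this template is rendered with the user's request supplied as user_input (other variables, such as the candidate memories, are not visible in this hunk). A minimal sketch of rendering it with Jinja2, assuming the templates directory shown in this diff:

from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("templates"))
template = env.get_template("ai_memory_judge_prompt.jinja2")

prompt = template.render(
    user_input="plot a scatter graph of food prices in Uganda"
)
print(prompt)  # the rendered prompt would then go to the judging model

Any variables not passed to render() fall back to Jinja2's default Undefined and render as empty strings, so the sketch stays valid even though the full template is truncated here.
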
2 changes: 1 addition & 1 deletion templates/generate_intent_from_history_prompt.jinja2
@@ -13,7 +13,7 @@ Here is the chat history:
Important:

- Be careful to note the chat history; the user may have just asked a follow-up question
- Put more emphasis on their last request
- Put more emphasis on their last input; it has a stronger influence on the intent than earlier inputs in chat_history
- Include all entities such as places

Intent format:
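
(The intent format itself is truncated in this view.) These two templates appear to work in sequence: one distils the chat history into a single intent, and the memory judge then compares that intent against stored memories and skills. A rough sketch of that chaining; the chat_history variable name comes from the text above, while the model name and helper function are assumptions:

from jinja2 import Environment, FileSystemLoader
from openai import OpenAI

env = Environment(loader=FileSystemLoader("templates"))
client = OpenAI()  # assumed client; app.py builds its own clients elsewhere

def ask(prompt: str) -> str:
    # Hypothetical helper: a single chat completion call
    resp = client.chat.completions.create(
        model="gpt-4o",  # assumed model name
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.choices[0].message.content

history = "user: what data do you have?\nassistant: ...\nuser: plot it for Chad"

# Step 1: turn the chat history into a single intent
intent = ask(
    env.get_template("generate_intent_from_history_prompt.jinja2").render(
        chat_history=history
    )
)

# Step 2: judge whether a stored memory or skill matches that intent
verdict = ask(
    env.get_template("ai_memory_judge_prompt.jinja2").render(user_input=intent)
)
print(verdict)
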
1 change: 0 additions & 1 deletion templates/openai_assistant_prompt.jinja2
@@ -100,7 +100,6 @@ WHERE

Conversely, if you do not exclude the aggregate data, you will get a mix of aggregated and disaggregated data.


4. HDX Shape files

NEVER query the database for shape files, they are too large.
238 changes: 97 additions & 141 deletions ui/chat-chainlit-assistant/app.py
@@ -59,7 +59,7 @@
config.ui.name = bot_name


class EventHandler(AsyncAssistantEventHandler):
class EventHandler(AssistantEventHandler):

def __init__(self, assistant_name: str) -> None:
"""
@@ -75,185 +75,142 @@ def __init__(self, assistant_name: str) -> None:
self.current_message: cl.Message = None
self.current_step: cl.Step = None
self.current_tool_call = None
self.current_message_text = ""
self.assistant_name = assistant_name

async def on_text_created(self, text) -> None:
@override
def on_event(self, event):
"""
Handles the event when a new text is created.
Handles the incoming event and performs the necessary actions based on the event type.
Args:
text: The newly created text.
event: The event object containing information about the event.
Returns:
None
"""
self.current_message = await cl.Message(
author=self.assistant_name, content=""
).send()

async def on_text_delta(self, delta, snapshot):
"""
Handles the text delta event.
Parameters:
- delta: The text delta object.
- snapshot: The current snapshot of the document.
Returns:
- None
"""
if delta.value is not None:
await self.current_message.stream_token(delta.value)
print(event.event)
run_id = event.data.id
if event.event == "thread.message.created":
self.current_message = run_sync(cl.Message(content="").send())
self.current_message_text = ""
print("Run started")
if event.event == "thread.message.completed":
self.handle_message_completed(event.data, run_id)
elif event.event == "thread.run.requires_action":
self.handle_requires_action(event.data, run_id)
elif event.event == "thread.message.delta":
self.handle_message_delta(event.data)
else:
print(event.data)
print(f"Unhandled event: {event.event}")

async def on_text_done(self, text):
def handle_message_delta(self, data):
"""
Callback method called when text input is done.
Handles the message delta data.
Args:
text (str): The text input provided by the user.
data: The message delta data.
Returns:
None
"""
await self.current_message.update()

async def on_tool_call_created(self, tool_call):
"""
Callback method called when a tool call is created.
Args:
tool_call: The tool call object representing the created tool call.
"""
self.current_tool_call = tool_call.id
self.current_step = cl.Step(name=tool_call.type, type="tool")
self.current_step.language = "python"
self.current_step.created_at = utc_now()
await self.current_step.send()
for content in data.delta.content:
if content.type == "text":
content = content.text.value
self.current_message_text += content
run_sync(self.current_message.stream_token(content))
elif content.type == "image_file":
file_id = content.image_file.file_id
image_data = sync_openai_client.files.content(file_id)
image_data_bytes = image_data.read()
png_file = f"{images_loc}{file_id}.png"
print(f"Writing image to {png_file}")
with open(png_file, "wb") as file:
file.write(image_data_bytes)
image = cl.Image(path=png_file, display="inline", size="large")
print(f"Image: {png_file}")
if not self.current_message.elements:
self.current_message.elements = []
self.current_message.elements.append(image)
run_sync(self.current_message.update())
else:
print(f"Unhandled delta type: {content.type}")

async def on_tool_call_delta(self, delta, snapshot):
def handle_message_completed(self, data, run_id):
"""
Handles the tool call delta event.
Handles the completion of a message.
Args:
delta (ToolCallDelta): The delta object representing the tool call event.
snapshot (Snapshot): The snapshot object representing the current state.
data: The data associated with the completed message.
run_id: The ID of the message run.
Returns:
None
"""
print(f"Tool call delta: {delta.type}")
if snapshot.id != self.current_tool_call:
self.current_tool_call = snapshot.id
self.current_step = cl.Step(name=delta.type, type="tool")
self.current_step.language = "python"
self.current_step.start = utc_now()
await self.current_step.send()

if delta.type == "code_interpreter":
if delta.code_interpreter.outputs:
for output in delta.code_interpreter.outputs:
if output.type == "logs":
error_step = cl.Step(name=delta.type, type="tool")
error_step.is_error = True
error_step.output = output.logs
error_step.language = "markdown"
error_step.start = self.current_step.start
error_step.end = utc_now()
await error_step.send()
else:
if delta.code_interpreter.input:
await self.current_step.stream_token(delta.code_interpreter.input)
# Add footer to self message. We have to start a new message so it's in right order
# TODO combine streaming with image and footer
run_sync(self.current_message.update())
self.current_message = run_sync(cl.Message(content="").send())

word_count = len(self.current_message_text.split())
if word_count > 10:
run_sync(self.current_message.stream_token(llm_footer))
run_sync(self.current_message.update())

async def on_tool_call_done(self, tool_call):
def handle_requires_action(self, data, run_id):
"""
Callback method called when a tool call is done.
Handles the required action by executing the specified tools and submitting the tool outputs.
Args:
tool_call: The tool call object representing the completed tool call.
data: The data containing the required action information.
run_id: The ID of the current run.
Returns:
None
"""
print("Tool call done!")
# Turning this off, analysis would stop suddenly
self.current_step.end = utc_now()
await self.current_step.update()
tool_outputs = []

async def on_image_file_done(self, image_file):
"""
Callback function called when an image file is done processing.
Args:
image_file: The image file object that has finished processing.
for tool in data.required_action.submit_tool_outputs.tool_calls:
print(tool)

Returns:
None
"""
image_id = image_file.file_id
response = await async_openai_client.files.with_raw_response.content(image_id)
image_element = cl.Image(
name=image_id, content=response.content, display="inline", size="large"
)
if not self.current_message.elements:
self.current_message.elements = []
self.current_message.elements.append(image_element)
await self.current_message.update()
function_name = tool.function.name
function_args = tool.function.arguments

async def on_end(
self,
):
print("\n end assistant > ", self.current_run_step_snapshot)
function_output = run_function(function_name, function_args)

async def on_exception(self, exception: Exception) -> None:
print("\n Exception > ", self.current_run_step_snapshot)
tool_outputs.append({"tool_call_id": tool.id, "output": function_output})

async def on_timeout(self) -> None:
print("\n Timeout > ", self.current_run_step_snapshot)
print("TOOL OUTPUTS: ")

@override
async def on_event(self, event):
# Retrieve events that are denoted with 'requires_action'
# since these will have our tool_calls
print(event.event)
if event.event == "thread.run.requires_action":
tool_outputs = []
for tool in event.data.required_action.submit_tool_outputs.tool_calls:
print(tool)
print(tool_outputs)

function_name = tool.function.name
function_args = tool.function.arguments
# Submit all tool_outputs at the same time
self.submit_tool_outputs(tool_outputs, run_id)

function_output = asyncio.run(
run_function(function_name, function_args)
)
def submit_tool_outputs(self, tool_outputs, run_id):
"""
Submits the tool outputs to the current run.
tool_outputs.append(
{"tool_call_id": tool.id, "output": function_output}
)
Args:
tool_outputs (list): A list of tool outputs to be submitted.
run_id (str): The ID of the current run.
print("TOOL OUTPUTS: ")

print(tool_outputs)

# Streaming
with sync_openai_client.beta.threads.runs.submit_tool_outputs_stream(
thread_id=self.current_run.thread_id,
run_id=self.current_run.id,
tool_outputs=tool_outputs,
event_handler=AssistantEventHandler(),
) as stream:
print("Tool output submitted successfully")
msg = await cl.Message(author=self.assistant_name, content="").send()
response = ""
for text in stream.text_deltas:
response += text
await msg.stream_token(text)
if len(response) > 5:
await msg.stream_token(llm_footer)
await msg.update()


async def run_function(function_name, function_args):
Returns:
None
"""
with sync_openai_client.beta.threads.runs.submit_tool_outputs_stream(
thread_id=self.current_run.thread_id,
run_id=self.current_run.id,
tool_outputs=tool_outputs,
event_handler=EventHandler(assistant_name=self.assistant_name),
) as stream:
# Needs this line, or it doesn't work! :)
for text in stream.text_deltas:
print(text, end="", flush=True)


def run_function(function_name, function_args):
"""
Run a function with the given name and arguments.
@@ -283,7 +240,7 @@ async def run_function(function_name, function_args):
return output


def print(*tup):
def print_to_log(*tup):
"""
Custom print function that logs the output using the logger.
@@ -624,13 +581,12 @@ async def main(message: cl.Message):

# Create and Stream a Run
print(f"Creating and streaming a run {assistant.id}")
async with async_openai_client.beta.threads.runs.stream(
with sync_openai_client.beta.threads.runs.stream(
thread_id=thread_id,
assistant_id=assistant.id,
event_handler=EventHandler(assistant_name=assistant.name),
# max_completion_tokens=20000
) as stream:
await stream.until_done()
stream.until_done()


@cl.on_audio_chunk
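
One detail worth calling out: the rewritten EventHandler is synchronous, but Chainlit's messaging API is async, so the new code bridges the two with run_sync (for example run_sync(cl.Message(content="").send()) in handle_message_delta above). A stripped-down sketch of that bridge, assuming run_sync comes from chainlit.sync:

import chainlit as cl
from chainlit.sync import run_sync  # assumption about the import location
from openai import AssistantEventHandler

class SyncUIHandler(AssistantEventHandler):
    # Synchronous OpenAI event handler that still drives the async Chainlit UI

    def on_event(self, event):
        if event.event == "thread.message.created":
            # run_sync executes the coroutine from synchronous code and returns its result
            self.message = run_sync(cl.Message(content="").send())
        elif event.event == "thread.message.delta":
            for part in event.data.delta.content or []:
                if part.type == "text":
                    run_sync(self.message.stream_token(part.text.value))

The handler in this commit follows the same pattern, and additionally writes image_file deltas to disk and attaches them to the message as cl.Image elements.
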
