diff --git a/templates/ai_memory_judge_prompt.jinja2 b/templates/ai_memory_judge_prompt.jinja2
index ebc38a81..4fb8a7af 100644
--- a/templates/ai_memory_judge_prompt.jinja2
+++ b/templates/ai_memory_judge_prompt.jinja2
@@ -8,7 +8,7 @@ You have specific memories you can match on, as well as generic skills that migh
 For example, if the user want to plot a scatter plot of food prices in Uganda, the top hit would be a match
 on a memory like 'plot a scatter graph of prices in Uganda in the last 5 years. If a specific memory match
-like this doesn't exist, match on a generik skill that can be used with the user's input parameters, for example
+like this doesn't exist, match on a generic skill that can be used with the user's input parameters, for example
 'plot a scatter graph of prices in a country'.
 
 Key points to consider:
@@ -17,6 +17,26 @@ Key points to consider:
 - If you user has a very general question, eg 'What data do you have' but the possible match is more specific,
   eg 'what data do you have for region X', it is not a match
 - 'Plot population pyramids' means the same as 'plot a pyramid plot'
+Examples of Matches:
+
+User intent: generate a line chart of Maize prices by year for 2013-2023 for Chad
+Match: plot a line chart of commodity prices monthly relative change for Chad from 2008-01-01 using HDX data as an image
+Reason: There is a general skill which can be used for the user's request
+
+User intent: plot population pyramids for chad
+Match: plot population pyramids for Chad
+Reason: The user asked for exactly this, same plot, same country
+
+Examples that are NOT matches:
+
+User intent: plot a map of population by state in Haiti
+Match: I would like a plot of population by state in Mali
+Reason: The user asked for Haiti, not Mali
+
+User intent: give me a plot of population in Chad
+Match: plot population pyramids for Chad
+Reason: The user asked for a general plot of population but didn't specify a plot type
+
 The user asked for this:
 
 {{ user_input }}
diff --git a/templates/generate_intent_from_history_prompt.jinja2 b/templates/generate_intent_from_history_prompt.jinja2
index e99b88e7..dd88e2f4 100644
--- a/templates/generate_intent_from_history_prompt.jinja2
+++ b/templates/generate_intent_from_history_prompt.jinja2
@@ -13,7 +13,7 @@ Here is the chat history:
 
 Important:
 - Be careful to note the chat history, they may have just asked a follow-up question
-- Put more emphasis on their last request
+- Put more emphasis on their last input, it has a stronger influence on the intent than earlier inputs in chat_history
 - include all entities such as places
 
 Intent format:
diff --git a/templates/openai_assistant_prompt.jinja2 b/templates/openai_assistant_prompt.jinja2
index 9c399b20..26adf007 100644
--- a/templates/openai_assistant_prompt.jinja2
+++ b/templates/openai_assistant_prompt.jinja2
@@ -100,7 +100,6 @@ WHERE
 
 Conversely, if you do not exclude the aggregate data, you will get a mix of aggregated and disaggregated data.
 
-
 4. HDX Shape files
 
 NEVER query the database for shape files, they are too large.
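For context on how these templates are consumed: `ai_memory_judge_prompt.jinja2` is a plain Jinja2 template keyed on `{{ user_input }}`. A minimal sketch of how it might be rendered before being sent to the judge model; the `FileSystemLoader` setup and the example input are assumptions, not shown in this diff:

```python
from jinja2 import Environment, FileSystemLoader

# Hypothetical loader setup; the repo's actual rendering code is not part of this diff
env = Environment(loader=FileSystemLoader("templates"))
template = env.get_template("ai_memory_judge_prompt.jinja2")

# {{ user_input }} is the only variable the template above requires
prompt = template.render(user_input="plot population pyramids for Chad")
print(prompt)
```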
diff --git a/ui/chat-chainlit-assistant/app.py b/ui/chat-chainlit-assistant/app.py
index 210354e3..6d87575f 100644
--- a/ui/chat-chainlit-assistant/app.py
+++ b/ui/chat-chainlit-assistant/app.py
@@ -59,7 +59,7 @@ config.ui.name = bot_name
 
 
-class EventHandler(AsyncAssistantEventHandler):
+class EventHandler(AssistantEventHandler):
     def __init__(self, assistant_name: str) -> None:
         """
@@ -75,185 +75,142 @@ def __init__(self, assistant_name: str) -> None:
         self.current_message: cl.Message = None
         self.current_step: cl.Step = None
         self.current_tool_call = None
+        self.current_message_text = ""
         self.assistant_name = assistant_name
 
-    async def on_text_created(self, text) -> None:
+    @override
+    def on_event(self, event):
         """
-        Handles the event when a new text is created.
+        Handles the incoming event and performs the necessary actions based on the event type.
 
         Args:
-            text: The newly created text.
+            event: The event object containing information about the event.
 
         Returns:
             None
         """
-        self.current_message = await cl.Message(
-            author=self.assistant_name, content=""
-        ).send()
-
-    async def on_text_delta(self, delta, snapshot):
-        """
-        Handles the text delta event.
-
-        Parameters:
-        - delta: The text delta object.
-        - snapshot: The current snapshot of the document.
-
-        Returns:
-        - None
-        """
-        if delta.value is not None:
-            await self.current_message.stream_token(delta.value)
+        print(event.event)
+        run_id = event.data.id
+        if event.event == "thread.message.created":
+            self.current_message = run_sync(cl.Message(content="").send())
+            self.current_message_text = ""
+            print("Run started")
+        elif event.event == "thread.message.completed":
+            self.handle_message_completed(event.data, run_id)
+        elif event.event == "thread.run.requires_action":
+            self.handle_requires_action(event.data, run_id)
+        elif event.event == "thread.message.delta":
+            self.handle_message_delta(event.data)
+        else:
+            print(event.data)
+            print(f"Unhandled event: {event.event}")
 
-    async def on_text_done(self, text):
+    def handle_message_delta(self, data):
         """
-        Callback method called when text input is done.
+        Handles the message delta data.
 
         Args:
-            text (str): The text input provided by the user.
+            data: The message delta data.
 
         Returns:
             None
         """
-        await self.current_message.update()
-
-    async def on_tool_call_created(self, tool_call):
-        """
-        Callback method called when a tool call is created.
-
-        Args:
-            tool_call: The tool call object representing the created tool call.
- """ - self.current_tool_call = tool_call.id - self.current_step = cl.Step(name=tool_call.type, type="tool") - self.current_step.language = "python" - self.current_step.created_at = utc_now() - await self.current_step.send() + for content in data.delta.content: + if content.type == "text": + content = content.text.value + self.current_message_text += content + run_sync(self.current_message.stream_token(content)) + elif content.type == "image_file": + file_id = content.image_file.file_id + image_data = sync_openai_client.files.content(file_id) + image_data_bytes = image_data.read() + png_file = f"{images_loc}{file_id}.png" + print(f"Writing image to {png_file}") + with open(png_file, "wb") as file: + file.write(image_data_bytes) + image = cl.Image(path=png_file, display="inline", size="large") + print(f"Image: {png_file}") + if not self.current_message.elements: + self.current_message.elements = [] + self.current_message.elements.append(image) + run_sync(self.current_message.update()) + else: + print(f"Unhandled delta type: {content.type}") - async def on_tool_call_delta(self, delta, snapshot): + def handle_message_completed(self, data, run_id): """ - Handles the tool call delta event. + Handles the completion of a message. Args: - delta (ToolCallDelta): The delta object representing the tool call event. - snapshot (Snapshot): The snapshot object representing the current state. + data: The data associated with the completed message. + run_id: The ID of the message run. Returns: None """ - print(f"Tool call delta: {delta.type}") - if snapshot.id != self.current_tool_call: - self.current_tool_call = snapshot.id - self.current_step = cl.Step(name=delta.type, type="tool") - self.current_step.language = "python" - self.current_step.start = utc_now() - await self.current_step.send() - - if delta.type == "code_interpreter": - if delta.code_interpreter.outputs: - for output in delta.code_interpreter.outputs: - if output.type == "logs": - error_step = cl.Step(name=delta.type, type="tool") - error_step.is_error = True - error_step.output = output.logs - error_step.language = "markdown" - error_step.start = self.current_step.start - error_step.end = utc_now() - await error_step.send() - else: - if delta.code_interpreter.input: - await self.current_step.stream_token(delta.code_interpreter.input) + # Add footer to self message. We have to start a new message so it's in right order + # TODO combine streaming with image and footer + run_sync(self.current_message.update()) + self.current_message = run_sync(cl.Message(content="").send()) + + word_count = len(self.current_message_text.split()) + if word_count > 10: + run_sync(self.current_message.stream_token(llm_footer)) + run_sync(self.current_message.update()) - async def on_tool_call_done(self, tool_call): + def handle_requires_action(self, data, run_id): """ - Callback method called when a tool call is done. + Handles the required action by executing the specified tools and submitting the tool outputs. Args: - tool_call: The tool call object representing the completed tool call. + data: The data containing the required action information. + run_id: The ID of the current run. Returns: None """ - print("Tool call done!") - # Turning this off, analysis would stop suddenly - self.current_step.end = utc_now() - await self.current_step.update() + tool_outputs = [] - async def on_image_file_done(self, image_file): - """ - Callback function called when an image file is done processing. - - Args: - image_file: The image file object that has finished processing. 
+        for tool in data.required_action.submit_tool_outputs.tool_calls:
+            print(tool)
 
-        Returns:
-            None
-        """
-        image_id = image_file.file_id
-        response = await async_openai_client.files.with_raw_response.content(image_id)
-        image_element = cl.Image(
-            name=image_id, content=response.content, display="inline", size="large"
-        )
-        if not self.current_message.elements:
-            self.current_message.elements = []
-        self.current_message.elements.append(image_element)
-        await self.current_message.update()
+            function_name = tool.function.name
+            function_args = tool.function.arguments
 
-    async def on_end(
-        self,
-    ):
-        print("\n end assistant > ", self.current_run_step_snapshot)
+            function_output = run_function(function_name, function_args)
 
-    async def on_exception(self, exception: Exception) -> None:
-        print("\n Exception > ", self.current_run_step_snapshot)
+            tool_outputs.append({"tool_call_id": tool.id, "output": function_output})
 
-    async def on_timeout(self) -> None:
-        print("\n Timeout > ", self.current_run_step_snapshot)
+        print("TOOL OUTPUTS: ")
 
-    @override
-    async def on_event(self, event):
-        # Retrieve events that are denoted with 'requires_action'
-        # since these will have our tool_calls
-        print(event.event)
-        if event.event == "thread.run.requires_action":
-            tool_outputs = []
-            for tool in event.data.required_action.submit_tool_outputs.tool_calls:
-                print(tool)
+        print(tool_outputs)
 
-                function_name = tool.function.name
-                function_args = tool.function.arguments
+        # Submit all tool_outputs at the same time
+        self.submit_tool_outputs(tool_outputs, run_id)
 
-                function_output = asyncio.run(
-                    run_function(function_name, function_args)
-                )
+    def submit_tool_outputs(self, tool_outputs, run_id):
+        """
+        Submits the tool outputs to the current run.
 
-                tool_outputs.append(
-                    {"tool_call_id": tool.id, "output": function_output}
-                )
+        Args:
+            tool_outputs (list): A list of tool outputs to be submitted.
+            run_id (str): The ID of the current run.
 
-                print("TOOL OUTPUTS: ")
-
-                print(tool_outputs)
-
-                # Streaming
-                with sync_openai_client.beta.threads.runs.submit_tool_outputs_stream(
-                    thread_id=self.current_run.thread_id,
-                    run_id=self.current_run.id,
-                    tool_outputs=tool_outputs,
-                    event_handler=AssistantEventHandler(),
-                ) as stream:
-                    print("Tool output submitted successfully")
-                    msg = await cl.Message(author=self.assistant_name, content="").send()
-                    response = ""
-                    for text in stream.text_deltas:
-                        response += text
-                        await msg.stream_token(text)
-                    if len(response) > 5:
-                        await msg.stream_token(llm_footer)
-                    await msg.update()
-
-
-async def run_function(function_name, function_args):
+        Returns:
+            None
+        """
+        with sync_openai_client.beta.threads.runs.submit_tool_outputs_stream(
+            thread_id=self.current_run.thread_id,
+            run_id=self.current_run.id,
+            tool_outputs=tool_outputs,
+            event_handler=EventHandler(assistant_name=self.assistant_name),
+        ) as stream:
+            # Needs this line, or it doesn't work! :)
+            for text in stream.text_deltas:
+                print(text, end="", flush=True)
+
+
+def run_function(function_name, function_args):
     """
     Run a function with the given name and arguments.
 
@@ -283,7 +240,7 @@ async def run_function(function_name, function_args):
     return output
 
 
-def print(*tup):
+def print_to_log(*tup):
     """
     Custom print function that logs the output using the logger.
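The class rewrite above is the core of this change: the per-callback `AsyncAssistantEventHandler` becomes a synchronous `AssistantEventHandler` whose single `on_event` dispatches by event name, and async Chainlit calls are bridged back with `run_sync`. A minimal sketch of that pattern, assuming Chainlit's `run_sync` helper (the same one the new code calls) and the standard Assistants streaming event names; the handler body is illustrative, not the full class from this diff:

```python
import chainlit as cl
from chainlit.sync import run_sync  # drives a coroutine to completion from sync code
from openai import AssistantEventHandler
from typing_extensions import override


class MinimalHandler(AssistantEventHandler):
    """Illustrative subset of the EventHandler above."""

    @override
    def on_event(self, event):
        # Callbacks on AssistantEventHandler are synchronous, so every async
        # Chainlit call is wrapped in run_sync
        if event.event == "thread.message.created":
            self.msg = run_sync(cl.Message(content="").send())
        elif event.event == "thread.message.delta":
            for part in event.data.delta.content:
                if part.type == "text":
                    run_sync(self.msg.stream_token(part.text.value))
```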
@@ -624,13 +581,12 @@ async def main(message: cl.Message):
 
     # Create and Stream a Run
     print(f"Creating and streaming a run {assistant.id}")
-    async with async_openai_client.beta.threads.runs.stream(
+    with sync_openai_client.beta.threads.runs.stream(
         thread_id=thread_id,
         assistant_id=assistant.id,
         event_handler=EventHandler(assistant_name=assistant.name),
-        # max_completion_tokens=20000
     ) as stream:
-        await stream.until_done()
+        stream.until_done()
 
 
 @cl.on_audio_chunk
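For reference, a minimal sketch of the end-to-end synchronous flow that `main()` now uses. Client construction, thread creation, and the assistant ID are assumptions (the real app reuses a per-session thread and a preconfigured assistant):

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

thread = client.beta.threads.create()
client.beta.threads.messages.create(
    thread_id=thread.id, role="user", content="What data do you have for Chad?"
)

# EventHandler is the class from this diff; until_done() blocks until the run
# finishes, dispatching each streaming event to on_event as it arrives
with client.beta.threads.runs.stream(
    thread_id=thread.id,
    assistant_id="asst_...",  # placeholder assistant ID
    event_handler=EventHandler(assistant_name="assistant"),
) as stream:
    stream.until_done()
```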