From b06664f87f8bbe9b07d606ba99b84fe8a3106bd2 Mon Sep 17 00:00:00 2001 From: userpj Date: Fri, 13 Dec 2024 15:40:18 +0800 Subject: [PATCH] =?UTF-8?q?chainlit=E6=96=B0=E5=A2=9Echatflow=20agent?= =?UTF-8?q?=E6=94=AF=E6=8C=81=20(#663)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/core/agent.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/python/core/agent.py b/python/core/agent.py index 4c248d485..d6d1df809 100644 --- a/python/core/agent.py +++ b/python/core/agent.py @@ -26,7 +26,7 @@ from appbuilder.core.component import Component from appbuilder.core.message import Message from appbuilder.utils.logger_util import logger -from appbuilder.core.console.appbuilder_client.data_class import ToolChoiceFunction, ToolChoice +from appbuilder.core.console.appbuilder_client.data_class import ToolChoiceFunction, ToolChoice, Action # 流式场景首包超时时,最大重试次数 MAX_RETRY_COUNT = 3 @@ -197,7 +197,7 @@ def run(self, message: Message, stream: bool=False): conn.close() """ - + component: Component user_session_config: Optional[Union[Any, str]] = None user_session: Optional[Any] = None @@ -556,6 +556,7 @@ def chainlit_agent(self, host='0.0.0.0', port=8091): self.prepare_chainlit_readme() conversation_ids = [] + interrupt_dict = {} def _chat(message: cl.Message): if len(conversation_ids) == 0: @@ -566,8 +567,31 @@ def _chat(message: cl.Message): file_id = self.component.upload_local_file( conversation_id, message.elements[0].path) file_ids.append(file_id) - return self.component.run(conversation_id=conversation_id, query=message.content, file_ids=file_ids, - stream=True, tool_choice=self.tool_choice) + + interrupt_ids = interrupt_dict.get(conversation_id, []) + interrupt_event_id = interrupt_ids.pop() if len(interrupt_ids) > 0 else None + action = None + if interrupt_event_id is not None: + action = Action.create_resume_action(interrupt_event_id) + + tmp_message = self.component.run(conversation_id=conversation_id, query=message.content, file_ids=file_ids, + stream=True, tool_choice=self.tool_choice, action=action) + res_message=list(tmp_message.content) + + interrupt_event_id = None + for ans in res_message: + for event in ans.events: + if event.content_type == "chatflow_interrupt": + interrupt_event_id = event.detail.get("interrupt_event_id") + if event.content_type == "publish_message" and event.event_type == "chatflow": + answer = event.detail.get("message") + ans.answer += answer + + if interrupt_event_id is not None: + interrupt_ids.append(interrupt_event_id) + interrupt_dict[conversation_id] = interrupt_ids + tmp_message.content = res_message + return tmp_message @cl.on_chat_start async def start(): @@ -575,6 +599,7 @@ async def start(): request_id = str(uuid.uuid4()) init_context(session_id=session_id, request_id=request_id) conversation_ids.append(self.component.create_conversation()) + interrupt_dict[conversation_ids[-1]] = [] @cl.on_message # this function will be called every time a user inputs a message in the UI async def main(message: cl.Message):