Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

修复chat_channel配置参数取值错误bug,增加ContextType.IMAGE支持,使其集成Midjourney插件时可以图生文和图生图; 优化dingtalk_channel回复打字机效果流式 AI卡片、dingtalk_message图片或富文本消息接收。 #1994

Merged
merged 6 commits into from
Jun 4, 2024
39 changes: 19 additions & 20 deletions channel/chat_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
if e_context.is_pass() or context is None:
return context
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
logger.debug("[WX]self message skipped")
logger.debug("[chat_channel]self message skipped")
return None

# 消息内容匹配过程,并处理content
if ctype == ContextType.TEXT:
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
logger.debug(content)
logger.debug("[WX]reference query skipped")
logger.debug("[chat_channel]reference query skipped")
return None

nick_name_black_list = conf().get("nick_name_black_list", [])
Expand All @@ -111,10 +111,10 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
nick_name = context["msg"].actual_user_nickname
if nick_name and nick_name in nick_name_black_list:
# 黑名单过滤
logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore")
logger.warning(f"[chat_channel] Nickname {nick_name} in In BlackList, ignore")
return None

logger.info("[WX]receive group at")
logger.info("[chat_channel]receive group at")
if not conf().get("group_at_off", False):
flag = True
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
Expand All @@ -130,13 +130,13 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
content = subtract_res
if not flag:
if context["origin_ctype"] == ContextType.VOICE:
logger.info("[WX]receive group voice, but checkprefix didn't match")
logger.info("[chat_channel]receive group voice, but checkprefix didn't match")
return None
else: # 单聊
nick_name = context["msg"].from_user_nickname
if nick_name and nick_name in nick_name_black_list:
# 黑名单过滤
logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore")
logger.warning(f"[chat_channel] Nickname '{nick_name}' in In BlackList, ignore")
return None

match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
Expand All @@ -147,7 +147,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
else:
return None
content = content.strip()
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""]))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
Expand All @@ -159,17 +159,16 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
elif context.type == ContextType.VOICE:
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context["desire_rtype"] = ReplyType.VOICE

return context

def _handle(self, context: Context):
if context is None or not context.content:
return
logger.debug("[WX] ready to handle context: {}".format(context))
logger.debug("[chat_channel] ready to handle context: {}".format(context))
# reply的构建步骤
reply = self._generate_reply(context)

logger.debug("[WX] ready to decorate reply: {}".format(reply))
logger.debug("[chat_channel] ready to decorate reply: {}".format(reply))

# reply的包装步骤
if reply and reply.content:
Expand All @@ -187,7 +186,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
)
reply = e_context["reply"]
if not e_context.is_pass():
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
context["channel"] = e_context["channel"]
reply = super().build_reply_content(context.content, context)
Expand All @@ -199,7 +198,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
try:
any_to_wav(file_path, wav_path)
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
logger.warning("[WX]any to wav error, use raw path. " + str(e))
logger.warning("[chat_channel]any to wav error, use raw path. " + str(e))
wav_path = file_path
# 语音识别
reply = super().build_voice_to_text(wav_path)
Expand All @@ -210,7 +209,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
os.remove(wav_path)
except Exception as e:
pass
# logger.warning("[WX]delete temp file error: " + str(e))
# logger.warning("[chat_channel]delete temp file error: " + str(e))

if reply.type == ReplyType.TEXT:
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
Expand All @@ -228,7 +227,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
pass
else:
logger.warning("[WX] unknown context type: {}".format(context.type))
logger.warning("[chat_channel] unknown context type: {}".format(context.type))
return
return reply

Expand All @@ -244,7 +243,7 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
desire_rtype = context.get("desire_rtype")
if not e_context.is_pass() and reply and reply.type:
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
logger.error("[WX]reply type not support: " + str(reply.type))
logger.error("[chat_channel]reply type not support: " + str(reply.type))
reply.type = ReplyType.ERROR
reply.content = "不支持发送的消息类型: " + str(reply.type)

Expand All @@ -265,10 +264,10 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
pass
else:
logger.error("[WX] unknown reply type: {}".format(reply.type))
logger.error("[chat_channel] unknown reply type: {}".format(reply.type))
return
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
return reply

def _send_reply(self, context: Context, reply: Reply):
Expand All @@ -281,14 +280,14 @@ def _send_reply(self, context: Context, reply: Reply):
)
reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type:
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context))
self._send(reply, context)

def _send(self, reply: Reply, context: Context, retry_cnt=0):
try:
self.send(reply, context)
except Exception as e:
logger.error("[WX] sendMsg error: {}".format(str(e)))
logger.error("[chat_channel] sendMsg error: {}".format(str(e)))
if isinstance(e, NotImplementedError):
return
logger.exception(e)
Expand Down Expand Up @@ -342,7 +341,7 @@ def consume(self):
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
if not context_queue.empty():
context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context))
logger.debug("[chat_channel] consume context: {}".format(context))
future: Future = handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
if session_id not in self.futures:
Expand Down
188 changes: 154 additions & 34 deletions channel/dingtalk/dingtalk_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,81 @@
@author huiwen
@Date 2023/11/28
"""

import copy
import json
# -*- coding=utf-8 -*-
import logging
import time

import dingtalk_stream
from dingtalk_stream import AckMessage
from dingtalk_stream.card_replier import AICardReplier
from dingtalk_stream.card_replier import AICardStatus
from dingtalk_stream.card_replier import CardReplier

from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel
from channel.dingtalk.dingtalk_message import DingTalkMessage
from bridge.context import Context
from bridge.reply import Reply
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from common.time_check import time_checker
from config import conf
from common.expired_dict import ExpiredDict
from bridge.context import ContextType
from channel.chat_channel import ChatChannel
import logging
from dingtalk_stream import AckMessage
import dingtalk_stream


class CustomAICardReplier(CardReplier):
    """Card replier providing a customized ``start`` for AI cards.

    The ``start`` method defined here marks the card data as PROCESSING and
    sends the card with ``at_sender=True`` and ``at_all=False``.
    """

    def __init__(self, dingtalk_client, incoming_message):
        # Zero-argument super(): this class derives from CardReplier, not
        # AICardReplier, so super(AICardReplier, self) would raise TypeError
        # ("obj must be an instance or subtype of type") on instantiation.
        super().__init__(dingtalk_client, incoming_message)

    def start(
        self,
        card_template_id: str,
        card_data: dict,
        recipients: list = None,
        support_forward: bool = True,
    ) -> str:
        """Create and send an AI card.

        :param card_template_id: template id forwarded to ``create_and_send_card``.
        :param card_data: card payload; it is deep-copied, so the caller's
            dict is not modified.
        :param recipients: forwarded to ``create_and_send_card`` unchanged.
        :param support_forward: forwarded to ``create_and_send_card`` unchanged.
        :return: whatever ``create_and_send_card`` returns.
        """
        # Work on a copy so that adding the status flag does not leak into
        # the dict owned by the caller.
        card_data_with_status = copy.deepcopy(card_data)
        card_data_with_status["flowStatus"] = AICardStatus.PROCESSING
        return self.create_and_send_card(
            card_template_id,
            card_data_with_status,
            at_sender=True,
            at_all=False,
            recipients=recipients,
            support_forward=support_forward,
        )


# Monkey-patch AICardReplier: replace its `start` method with the
# implementation defined on CustomAICardReplier.
AICardReplier.start = CustomAICardReplier.start


def _check(func):
    """Decorate a message handler so unwanted messages are dropped first.

    The wrapped handler is called as ``func(self, cmsg)`` only when none of
    the following applies (otherwise the wrapper returns ``None``):

    * ``cmsg.msg_id`` is already present in ``self.receivedMsgs``;
    * the ``hot_reload`` config value equals ``True`` and ``cmsg.create_time``
      is more than 60 seconds older than the current time;
    * ``cmsg.my_msg`` is truthy and the message is not a group message.

    :param func: handler taking ``(self, cmsg)``.
    :return: the filtering wrapper.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    # wraps() preserves the handler's name/docstring, which the bare wrapper
    # used to hide from logs and from any decorator stacked on top.
    @functools.wraps(func)
    def wrapper(self, cmsg: DingTalkMessage):
        msg_id = cmsg.msg_id
        if msg_id in self.receivedMsgs:
            logger.info("DingTalk message {} already received, ignore".format(msg_id))
            return
        # Record the id before any further filtering so repeats are ignored.
        self.receivedMsgs[msg_id] = True
        create_time = cmsg.create_time  # message timestamp
        # `== True` is kept on purpose: it preserves the original matching of
        # the configured value exactly.
        if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60:
            # Skip history messages older than one minute.
            logger.debug("[DingTalk] History message {} skipped".format(msg_id))
            return
        if cmsg.my_msg and not cmsg.is_group:
            logger.debug("[DingTalk] My message {} skipped".format(msg_id))
            return
        return func(self, cmsg)

    return wrapper


@singleton
Expand All @@ -39,62 +100,121 @@ def __init__(self):
super(dingtalk_stream.ChatbotHandler, self).__init__()
self.logger = self.setup_logger()
# 历史消息id暂存,用于幂等控制
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
logger.info("[dingtalk] client_id={}, client_secret={} ".format(
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds"))
logger.info("[DingTalk] client_id={}, client_secret={} ".format(
self.dingtalk_client_id, self.dingtalk_client_secret))
# 无需群校验和前缀
conf()["group_name_white_list"] = ["ALL_GROUP"]
# 单聊无需前缀
conf()["single_chat_prefix"] = [""]

def startup(self):
    """Create the DingTalk stream client, register this object and start it.

    A credential is built from ``self.dingtalk_client_id`` and
    ``self.dingtalk_client_secret``; this object is registered as the
    callback handler for ``ChatbotMessage.TOPIC`` before ``start_forever``
    is called.
    """
    stream_client = dingtalk_stream.DingTalkStreamClient(
        dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
    )
    chatbot_topic = dingtalk_stream.chatbot.ChatbotMessage.TOPIC
    stream_client.register_callback_handler(chatbot_topic, self)
    stream_client.start_forever()

async def process(self, callback: dingtalk_stream.CallbackMessage):
    """Convert an incoming stream callback into a DingTalkMessage and dispatch it.

    :param callback: callback whose ``data`` is parsed with
        ``ChatbotMessage.from_dict``.
    :return: ``(AckMessage.STATUS_OK, 'OK')`` on success, or
        ``(AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR')`` when any exception
        is raised while building or handling the message.
    """
    try:
        chatbot_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
        # This object is handed over as the image download handler.
        wrapped_msg = DingTalkMessage(chatbot_message, self)
        dispatch = self.handle_group if wrapped_msg.is_group else self.handle_single
        dispatch(wrapped_msg)
    except Exception as e:
        logger.error(f"dingtalk process error={e}")
        return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
    return AckMessage.STATUS_OK, 'OK'

@time_checker
@_check
def handle_single(self, cmsg: DingTalkMessage):
    """Handle a single-chat (one-to-one) message.

    Logs one debug line describing the message type, prepends the first
    configured ``single_chat_prefix`` to text messages, then builds a context
    with ``_compose_context`` and passes it to ``produce`` when one is
    returned.

    :param cmsg: the received message.
    """
    # Exactly one debug line per message: the stale "[dingtalk]" duplicates
    # and the unused `expression` local have been dropped.
    if cmsg.ctype == ContextType.VOICE:
        logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
    elif cmsg.ctype == ContextType.IMAGE:
        logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
    elif cmsg.ctype == ContextType.IMAGE_CREATE:
        logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
    elif cmsg.ctype == ContextType.PATPAT:
        logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
    elif cmsg.ctype == ContextType.TEXT:
        # The prefix is added before logging so the log shows the content
        # that is actually handed to _compose_context.
        cmsg.content = conf()["single_chat_prefix"][0] + cmsg.content
        logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
    else:
        logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
    context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
    if context:
        self.produce(context)


@time_checker
@_check
def handle_group(self, cmsg: DingTalkMessage):
    """Handle a group-chat message.

    Logs one debug line describing the message type, prepends the first
    configured ``group_chat_prefix`` to text messages, then builds a context
    with ``_compose_context``. When a context is returned it is flagged with
    ``no_need_at = True`` and passed to ``produce``.

    :param cmsg: the received message.
    """
    if cmsg.ctype == ContextType.VOICE:
        logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
    elif cmsg.ctype == ContextType.IMAGE:
        logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
    elif cmsg.ctype == ContextType.IMAGE_CREATE:
        logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
    elif cmsg.ctype == ContextType.PATPAT:
        logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
    elif cmsg.ctype == ContextType.TEXT:
        # Fall back to an empty prefix when the option is missing or empty,
        # instead of failing on the [0] lookup.
        group_prefixes = conf().get("group_chat_prefix") or [""]
        cmsg.content = group_prefixes[0] + cmsg.content
        # Previously mislabelled as a "patpat" message.
        logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
    else:
        logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
    context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
    if context:
        # Only touch the context once it is known to exist: _compose_context
        # can return None, and item assignment on None raises TypeError.
        context['no_need_at'] = True
        self.produce(context)

# NOTE(review): this is a second definition of `process` in this view. It
# builds DingTalkMessage from the incoming message alone and selects the
# handler from `conversation_type`, whereas the earlier definition also passes
# a download handler and checks `is_group`. Presumably this is the superseded
# pre-change version shown by the diff -- confirm which definition should remain.
async def process(self, callback: dingtalk_stream.CallbackMessage):
try:
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
dingtalk_msg = DingTalkMessage(incoming_message)
# conversation_type '1' is routed to the single-chat handler; every other
# value goes to the group handler.
if incoming_message.conversation_type == '1':
self.handle_single(dingtalk_msg)
else:
self.handle_group(dingtalk_msg)
return AckMessage.STATUS_OK, 'OK'
except Exception as e:
# Any failure is logged and reported through self.FAILED_MSG.
logger.error(e)
return self.FAILED_MSG

def send(self, reply: Reply, context: Context):
    """Send a reply back to the conversation the message came from.

    IMAGE_URL, IMAGE and TEXT replies are sent as an AI markdown card built
    by ``generate_button_markdown_content``; in group chats an additional
    plain-text notice is sent after the card. Every other reply type is sent
    as plain text via ``reply_text``.

    :param reply: the reply to deliver.
    :param context: context of the original message; ``context.kwargs['msg']``
        supplies ``is_group`` and ``incoming_message``.
    """
    receiver = context["receiver"]
    isgroup = context.kwargs['msg'].is_group
    incoming_message = context.kwargs['msg'].incoming_message
    logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
    # The unconditional reply_text(reply.content, ...) that used to sit here
    # delivered every reply twice (plain text first, then again below).

    if reply.type not in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
        # Other reply types are not supported as cards yet: plain text only.
        self.reply_text(reply.content, incoming_message)
        return

    button_list, markdown_content = self.generate_button_markdown_content(context, reply)
    self.reply_ai_markdown_button(
        incoming_message,
        markdown_content,
        button_list,
        "",
        "📌 内容由AI-Bot生成",
        "",
        [incoming_message.sender_staff_id],
    )
    if isgroup:
        # Group chats get a short text notice after the card.
        self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)

def generate_button_markdown_content(self, context, reply):
    """Build the button list and markdown text used for a card reply.

    ``image_url`` and ``promptEn`` are read from ``context.kwargs``. When both
    are present the result holds one button linking to the image and a
    markdown body made of the prompt, the image and the reply text; otherwise
    the button list is empty and the markdown holds only the reply content.

    :param context: message context providing ``kwargs``.
    :param reply: reply whose ``content`` is embedded in the markdown.
    :return: tuple ``(button_list, markdown_content)``.
    """
    extras = context.kwargs
    image_url = extras.get("image_url")
    promptEn = extras.get("promptEn")
    reply_text = reply.content

    # The template lines below start at column 0 on purpose: any leading
    # whitespace would become part of the generated markdown.
    if image_url is None or promptEn is None:
        button_list = []
        markdown_content = f"""
{reply_text}
"""
    else:
        view_original = {"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"}
        button_list = [view_original]
        markdown_content = f"""
{promptEn}

!["图片"]({image_url})

{reply_text}

"""
    logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , markdown_content={markdown_content}")

    return button_list, markdown_content
Loading