Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

主要是更新语音识别和兼容性 #634

Merged
merged 11 commits into from
Mar 27, 2023
6 changes: 3 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# encoding:utf-8

import config
from config import conf, load_config
from channel import channel_factory
from common.log import logger

Expand All @@ -9,10 +9,10 @@
def run():
try:
# load config
config.load_config()
load_config()

# create channel
channel_name='wx'
channel_name=conf().get('channel_type', 'wx')
channel = channel_factory.create_channel(channel_name)
if channel_name=='wx':
PluginManager().load_plugins()
Expand Down
6 changes: 3 additions & 3 deletions bot/bot_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

def create_bot(bot_type):
"""
create a channel instance
:param channel_type: channel type code
:return: channel instance
create a bot_type instance
:param bot_type: bot type code
:return: bot instance
"""
if bot_type == const.BAIDU:
# Baidu Unit对话接口
Expand Down
133 changes: 85 additions & 48 deletions channel/wechat/wechat_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
"""

import os
import requests
import io
import time
from lib import itchat
import json
from lib.itchat.content import *
Expand All @@ -17,17 +20,18 @@
from config import conf
from common.time_check import time_checker
from plugins import *
import requests
import io
import time
from voice.audio_convert import mp3_to_wav


thread_pool = ThreadPoolExecutor(max_workers=8)


def thread_pool_callback(worker):
worker_exception = worker.exception()
if worker_exception:
logger.exception("Worker return exception: {}".format(worker_exception))


@itchat.msg_register(TEXT)
def handler_single_msg(msg):
WechatChannel().handle_text(msg)
Expand All @@ -48,21 +52,24 @@ def handler_group_voice(msg):
WechatChannel().handle_group_voice(msg)
return None



class WechatChannel(Channel):
def __init__(self):
self.userName = None
self.nickName = None

def startup(self):

itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# login by scan QRCode
hotReload = conf().get('hot_reload', False)
try:
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
except Exception as e:
if hotReload:
logger.error("Hot reload failed, try to login without hot reload")
logger.error(
"Hot reload failed, try to login without hot reload")
itchat.logout()
os.remove("itchat.pkl")
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
Expand Down Expand Up @@ -105,7 +112,8 @@ def handle_voice(self, msg):

@time_checker
def handle_text(self, msg):
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
logger.debug("[WX]receive text msg: " +
json.dumps(msg, ensure_ascii=False))
content = msg['Text']
from_user_id = msg['FromUserName']
to_user_id = msg['ToUserName'] # 接收人id
Expand All @@ -119,7 +127,7 @@ def handle_text(self, msg):
other_user_id = from_user_id
create_time = msg['CreateTime'] # 消息时间
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message skipped")
return
if "」\n- - - - - - - - - - - - - - -" in content:
Expand All @@ -130,25 +138,29 @@ def handle_text(self, msg):
elif match_prefix is None:
return
context = Context()
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
context.kwargs = {'isgroup': False, 'msg': msg,
'receiver': other_user_id, 'session_id': other_user_id}

img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
img_match_prefix = check_prefix(
content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip()
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT

context.content = content
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
thread_pool.submit(self.handle, context).add_done_callback(
thread_pool_callback)

@time_checker
def handle_group(self, msg):
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
logger.debug("[WX]receive group msg: " +
json.dumps(msg, ensure_ascii=False))
group_name = msg['User'].get('NickName', None)
group_id = msg['User'].get('UserName', None)
create_time = msg['CreateTime'] # 消息时间
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history group message skipped")
return
if not group_name:
Expand All @@ -166,12 +178,14 @@ def handle_group(self, msg):
return ""
config = conf()
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
or check_contain(origin_content, config.get('group_chat_keyword'))
or check_contain(origin_content, config.get('group_chat_keyword'))
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
context = Context()
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}

img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
context.kwargs = {'isgroup': True,
'msg': msg, 'receiver': group_id}

img_match_prefix = check_prefix(
content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip()
context.type = ContextType.IMAGE_CREATE
Expand All @@ -187,7 +201,8 @@ def handle_group(self, msg):
else:
context['session_id'] = msg['ActualUserName']

thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
thread_pool.submit(self.handle, context).add_done_callback(
thread_pool_callback)

def handle_group_voice(self, msg):
if conf().get('group_speech_recognition', False) != True:
Expand Down Expand Up @@ -217,7 +232,7 @@ def handle_group_voice(self, msg):
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)

# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply : Reply, receiver):
def send(self, reply: Reply, receiver):
if reply.type == ReplyType.TEXT:
itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
Expand All @@ -226,17 +241,19 @@ def send(self, reply : Reply, receiver):
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
elif reply.type == ReplyType.VOICE:
itchat.send_file(reply.content, toUserName=receiver)
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
logger.info('[WX] sendFile={}, receiver={}'.format(
reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
logger.info('[WX] sendImage url={}, receiver={}'.format(
img_url, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
Expand All @@ -247,32 +264,46 @@ def handle(self, context):
reply = Reply()

logger.debug('[WX] ready to handle context: {}'.format(context))

# reply的构建步骤
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass():
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息
msg = context['msg']
file_name = TmpDir().path() + context.content
msg.download(file_name)
reply = super().build_voice_to_text(file_name)
if reply.type == ReplyType.TEXT:
content = reply.content # 语音转文字后,将文字内容作为新的context
# 如果是群消息,判断是否触发关键字
if context['isgroup']:
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
match_contain = check_contain(content, conf().get('group_chat_keyword'))
logger.debug('[WX] group chat prefix match: {}'.format(match_prefix))
if match_prefix is None and match_contain is None:
return
mp3_path = TmpDir().path() + context.content
msg.download(mp3_path)
# mp3转wav
wav_path = os.path.splitext(mp3_path)[0] + '.wav'
mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
# 语音识别
reply = super().build_voice_to_text(wav_path)
# 删除临时文件
os.remove(wav_path)
os.remove(mp3_path)
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
content = reply.content # 语音转文字后,将文字内容作为新的context
context.type = ContextType.TEXT
if (context["isgroup"] == True):
# 校验关键字
match_prefix = check_prefix(content, conf().get('group_chat_prefix')) \
or check_contain(content, conf().get('group_chat_keyword'))
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
if match_prefix is not None:
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
prefixes = conf().get('group_chat_prefix')
for prefix in prefixes:
if content.startswith(prefix):
content = content.replace(prefix, '', 1).strip()
break
else:
if match_prefix:
content = content.replace(match_prefix, '', 1).strip()
logger.info("[WX]receive voice check prefix: " + 'False')
return

img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip()
Expand All @@ -289,16 +320,19 @@ def handle(self, context):
return

logger.debug('[WX] ready to decorate reply: {}'.format(reply))

# reply的包装步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass() and reply and reply.type:
if reply.type == ReplyType.TEXT:
reply_text = reply.content
if context['isgroup']:
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
reply_text = '@' + \
context['msg']['ActualNickName'] + \
' ' + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
else:
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
Expand All @@ -308,15 +342,18 @@ def handle(self, context):
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass
else:
logger.error('[WX] unknown reply type: {}'.format(reply.type))
logger.error(
'[WX] unknown reply type: {}'.format(reply.type))
return

# reply的发送步骤
# reply的发送步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass() and reply and reply.type:
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
logger.debug('[WX] ready to send reply: {} to {}'.format(
reply, context['receiver']))
self.send(reply, context['receiver'])

def check_prefix(content, prefix_list):
Expand Down
Loading