forked from zhayujie/chatgpt-on-wechat
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request zhayujie#2071 from lmy668/master
feat#add minmax model
- Loading branch information
Showing
6 changed files
with
246 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# encoding:utf-8 | ||
|
||
import time | ||
|
||
import openai | ||
import openai.error | ||
from bot.bot import Bot | ||
from bot.minimax.minimax_session import MinimaxSession | ||
from bot.session_manager import SessionManager | ||
from bridge.context import Context, ContextType | ||
from bridge.reply import Reply, ReplyType | ||
from common.log import logger | ||
from config import conf, load_config | ||
from bot.chatgpt.chat_gpt_session import ChatGPTSession | ||
import requests | ||
from common import const | ||
|
||
|
||
# MiniMax chat model API | ||
class MinimaxBot(Bot):
    """Bot that answers text queries through the MiniMax chat-completion HTTP API.

    Conversation history is kept per session in a ``SessionManager`` of
    ``MinimaxSession`` objects; every request sends that session's messages
    together with the static bot settings held in ``self.request_body``.
    """

    def __init__(self):
        super().__init__()
        # NOTE: only "model" is forwarded to the API by reply_text();
        # "temperature" and "top_p" are read from config but not sent.
        self.args = {
            "model": conf().get("model") or "abab6.5",  # name of the chat model
            "temperature": conf().get("temperature", 0.3),  # if set, must be within [0, 1]; 0.3 is the recommended value
            "top_p": conf().get("top_p", 0.95),  # use the default value
        }
        self.api_key = conf().get("Minimax_api_key")
        self.group_id = conf().get("Minimax_group_id")
        self.base_url = conf().get("Minimax_base_url", f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={self.group_id}")
        # Request template: tokens_to_generate / bot_setting / reply_constraints may be adjusted as needed.
        # "messages" stays empty here on purpose; reply_text() fills a per-request copy.
        self.request_body = {
            "model": self.args["model"],
            "tokens_to_generate": 2048,
            "reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
            "messages": [],
            "bot_setting": [
                {
                    "bot_name": "MM智能助理",
                    "content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
                }
            ],
        }
        self.sessions = SessionManager(MinimaxSession, model=const.MiniMax)

    def reply(self, query, context: Context = None) -> Reply:
        """Produce a reply for ``query``.

        Text contexts are first checked against the built-in commands (clear
        this session's memory, clear all sessions, reload the configuration);
        any other text is appended to the session and sent to the API.

        :param query: the user's message text
        :param context: message context; ``context.type`` and
            ``context["session_id"]`` are read, and an optional
            ``Minimax_model`` entry overrides the configured model
        :return: a ``Reply`` of type INFO (command handled), TEXT (model
            answer) or ERROR (failed request or unsupported context type)
        """
        # acquire reply content
        logger.info("[Minimax_AI] query={}".format(query))
        if context.type == ContextType.TEXT:
            session_id = context["session_id"]
            reply = None
            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
            if query in clear_memory_commands:
                self.sessions.clear_session(session_id)
                reply = Reply(ReplyType.INFO, "记忆已清除")
            elif query == "#清除所有":
                self.sessions.clear_all_session()
                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
            elif query == "#更新配置":
                load_config()
                reply = Reply(ReplyType.INFO, "配置已更新")
            if reply:
                return reply
            session = self.sessions.session_query(query, session_id)
            logger.debug("[Minimax_AI] session query={}".format(session))

            # A per-context model name takes precedence over the configured one.
            model = context.get("Minimax_model")
            new_args = self.args.copy()
            if model:
                new_args["model"] = model

            reply_content = self.reply_text(session, args=new_args)
            logger.debug(
                "[Minimax_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
                    session.messages,
                    session_id,
                    reply_content["content"],
                    reply_content["completion_tokens"],
                )
            )
            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
                # reply_text() reports failures as 0 tokens plus an error message.
                reply = Reply(ReplyType.ERROR, reply_content["content"])
            elif reply_content["completion_tokens"] > 0:
                # Only successful answers are recorded in the session history.
                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
                reply = Reply(ReplyType.TEXT, reply_content["content"])
            else:
                reply = Reply(ReplyType.ERROR, reply_content["content"])
                logger.debug("[Minimax_AI] reply {} used 0 tokens.".format(reply_content))
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def reply_text(self, session: MinimaxSession, args=None, retry_count=0) -> dict:
        """POST the session's messages to the MiniMax endpoint and return the answer.

        :param session: a conversation session whose ``messages`` are sent
        :param args: optional request arguments; a truthy ``"model"`` entry
            overrides the model of the request template
        :param retry_count: number of attempts already made; the call retries
            itself while this is below 2
        :return: on success a dict with ``total_tokens``, ``completion_tokens``
            and ``content``; on failure a dict with ``completion_tokens`` set
            to 0 and a user-facing error message in ``content``
        """
        try:
            headers = {"Content-Type": "application/json", "Authorization": "Bearer " + self.api_key}
            # Build a fresh body for every request. Extending the shared template's
            # "messages" list would pile up history across calls and retries and mix
            # the conversations of different sessions.
            request_body = dict(self.request_body)
            request_body["messages"] = list(session.messages)
            if args and args.get("model"):
                request_body["model"] = args["model"]
            logger.info("[Minimax_AI] request_body={}".format(request_body))
            res = requests.post(self.base_url, headers=headers, json=request_body)

            if res.status_code == 200:
                response = res.json()
                return {
                    "total_tokens": response["usage"]["total_tokens"],
                    # NOTE(review): total_tokens is reported as completion_tokens as well;
                    # presumably the endpoint exposes no separate completion count — confirm.
                    "completion_tokens": response["usage"]["total_tokens"],
                    "content": response["reply"],
                }

            # Error bodies may be non-JSON or lack an "error" object; parse defensively
            # so the status-specific handling below is always reached.
            try:
                payload = res.json()
            except ValueError:
                payload = None
            error = payload.get("error") if isinstance(payload, dict) else None
            if not isinstance(error, dict):
                error = {}
            logger.error(f"[Minimax_AI] chat failed, status_code={res.status_code}, " f"msg={error.get('message')}, type={error.get('type')}")

            result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
            need_retry = False
            if res.status_code >= 500:
                # server error, need retry
                logger.warning(f"[Minimax_AI] do retry, times={retry_count}")
                need_retry = retry_count < 2
            elif res.status_code == 401:
                result["content"] = "授权失败,请检查API Key是否正确"
            elif res.status_code == 429:
                result["content"] = "请求过于频繁,请稍后再试"
                need_retry = retry_count < 2

            if need_retry:
                time.sleep(3)
                return self.reply_text(session, args, retry_count + 1)
            return result
        except Exception as e:
            # Boundary handler: any unexpected failure (network error, malformed
            # success payload, missing API key) is logged and retried up to twice.
            logger.exception(e)
            need_retry = retry_count < 2
            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
            if need_retry:
                return self.reply_text(session, args, retry_count + 1)
            return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from bot.session_manager import Session | ||
from common.log import logger | ||
|
||
""" | ||
e.g. | ||
[ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
{"role": "user", "content": "Who won the world series in 2020?"}, | ||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, | ||
{"role": "user", "content": "Where was it played?"} | ||
] | ||
""" | ||
|
||
|
||
class MinimaxSession(Session):
    """Conversation history stored as sender_type / sender_name / text entries."""

    def __init__(self, session_id, system_prompt=None, model="minimax"):
        super().__init__(session_id, system_prompt)
        self.model = model
        # NOTE(review): reset() is deliberately not called here (it was disabled
        # in the original code); presumably the base implementation seeds a
        # message in a different format — confirm against Session.

    def add_query(self, query):
        """Append a user message; the session id is used as the sender name."""
        user_item = {"sender_type": "USER", "sender_name": self.session_id, "text": query}
        self.messages.append(user_item)

    def add_reply(self, reply):
        """Append a bot message holding the model's answer."""
        assistant_item = {"sender_type": "BOT", "sender_name": "MM智能助理", "text": reply}
        self.messages.append(assistant_item)

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop old messages until the history fits within ``max_tokens``.

        :param max_tokens: token budget for the whole history
        :param cur_tokens: fallback token count, used only when the precise
            count via ``calc_tokens()`` raises
        :return: the token count after trimming
        :raises Exception: re-raises the counting error when no fallback
            ``cur_tokens`` was supplied
        """
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            # Index 0 is never removed; trimming always drops the entry at index 1.
            if len(self.messages) > 2:
                self.messages.pop(1)
            elif len(self.messages) == 2 and self.messages[1].get("sender_type") == "BOT":
                self.messages.pop(1)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = cur_tokens - max_tokens
                break
            elif len(self.messages) == 2 and self.messages[1].get("sender_type") == "USER":
                # A single user message that is over budget cannot be trimmed further.
                logger.warning("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                # Without a precise count, only a rough downward estimate is possible.
                cur_tokens = cur_tokens - max_tokens
        return cur_tokens

    def calc_tokens(self):
        """Return the estimated token count of the current history."""
        return num_tokens_from_messages(self.messages, self.model)
|
||
|
||
def num_tokens_from_messages(messages, model):
    """Return a rough token estimate for a list of messages.

    The estimate is the total character length of each message's ``"text"``
    field, falling back to ``"content"`` when ``"text"`` is absent. Entries
    with neither field, or with a null value, count as zero instead of
    raising.

    :param messages: iterable of message dicts
    :param model: model name; accepted for interface compatibility and unused
    :return: estimated number of tokens
    """
    # Rule of thumb quoted by the original author: for Chinese text one token
    # usually corresponds to one character; for English text one token usually
    # corresponds to 3-4 letters or one word.
    # NOTE(review): the rule was attributed to
    # https://help.aliyun.com/document_detail/2586397.html, which is Alibaba
    # Cloud documentation rather than MiniMax's — confirm it applies here.
    # String length is used as a coarse estimate; it does not affect normal use.
    tokens = 0
    for msg in messages:
        text = msg.get("text")
        if text is None:
            text = msg.get("content")
        tokens += len(text or "")
    return tokens
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters