diff --git a/bot/gemini/google_gemini_bot.py b/bot/gemini/google_gemini_bot.py index 86eb12400..12435b03f 100644 --- a/bot/gemini/google_gemini_bot.py +++ b/bot/gemini/google_gemini_bot.py @@ -13,6 +13,7 @@ from bridge.reply import Reply, ReplyType from common.log import logger from config import conf +from bot.chatgpt.chat_gpt_session import ChatGPTSession from bot.baidu.baidu_wenxin_session import BaiduWenxinSession from google.generativeai.types import HarmCategory, HarmBlockThreshold @@ -23,8 +24,8 @@ class GoogleGeminiBot(Bot): def __init__(self): super().__init__() self.api_key = conf().get("gemini_api_key") - # 复用文心的token计算方式 - self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo") + # 复用chatGPT的token计算方式 + self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") self.model = conf().get("model") or "gemini-pro" if self.model == "gemini": self.model = "gemini-pro" @@ -37,6 +38,7 @@ def reply(self, query, context: Context = None) -> Reply: session_id = context["session_id"] session = self.sessions.session_query(query, session_id) gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages)) + logger.debug(f"[Gemini] messages={gemini_messages}") genai.configure(api_key=self.api_key) model = genai.GenerativeModel(self.model) @@ -81,6 +83,8 @@ def _convert_to_gemini_messages(self, messages: list): role = "user" elif msg.get("role") == "assistant": role = "model" + elif msg.get("role") == "system": + role = "user" else: continue res.append({ @@ -97,7 +101,11 @@ def filter_messages(messages: list): return res for i in range(len(messages) - 1, -1, -1): message = messages[i] - if message.get("role") != turn: + role = message.get("role") + if role == "system": + res.insert(0, message) + continue + if role != turn: continue res.insert(0, message) if turn == "user":