diff --git a/bot.py b/bot.py index c3d38e18..024107e5 100644 --- a/bot.py +++ b/bot.py @@ -23,7 +23,8 @@ reset_ENGINE, update_language, get_robot, - get_image_message + get_image_message, + get_ENGINE ) from utils.i18n import strings @@ -127,12 +128,15 @@ async def command_bot(update, context, language=None, prompt=translator_prompt, elif reply_to_message_text and not update_message.reply_to_message.from_user.is_bot: message = reply_to_message_text + "\n" + message - robot, role = get_robot() - if "gpt" in config.GPT_ENGINE or (config.CLAUDE_API and "claude-3" in config.GPT_ENGINE): + if robot is None: + robot, role = get_robot(chatid) + engine = get_ENGINE(chatid) + if "gpt" in engine or (config.CLAUDE_API and "claude-3" in engine): message = [{"type": "text", "text": message}] - message = get_image_message(image_url, message) + message = get_image_message(image_url, message, chatid) # print("robot", robot) await context.bot.send_chat_action(chat_id=chatid, action=ChatAction.TYPING) + title = f"`🤖️ {engine}`\n\n" await getChatGPT(update, context, title, robot, message, chatid, messageid) else: message = await context.bot.send_message( @@ -283,16 +287,17 @@ async def delete_message(update, context, messageid, delay=10): @decorators.Authorization async def button_press(update, context): """Function to handle the button press""" - info_message = update_info_message() callback_query = update.callback_query + chatid = callback_query.message.chat_id + info_message = update_info_message(chatid) await callback_query.answer() data = callback_query.data banner = strings['message_banner'][get_current_lang()] if data.endswith("ENGINE"): data = data[:-6] - update_ENGINE(data) + update_ENGINE(data, chatid) try: - info_message = update_info_message() + info_message = update_info_message(chatid) if info_message + banner != callback_query.message.text: message = await callback_query.edit_message_text( text=escape(info_message + banner), @@ -317,7 +322,7 @@ async def button_press(update, context): elif "language" in data: update_language() update_ENGINE() - info_message = update_info_message() + info_message = update_info_message(chatid) message = await callback_query.edit_message_text( text=escape(info_message), reply_markup=InlineKeyboardMarkup(update_first_buttons_message()), @@ -328,7 +333,7 @@ async def button_press(update, context): PLUGINS[data] = not PLUGINS[data] except: setattr(config, data, not getattr(config, data)) - info_message = update_info_message() + info_message = update_info_message(chatid) message = await callback_query.edit_message_text( text=escape(info_message), reply_markup=InlineKeyboardMarkup(update_first_buttons_message()), @@ -339,7 +344,8 @@ async def button_press(update, context): @decorators.GroupAuthorization @decorators.Authorization async def info(update, context): - info_message = update_info_message() + chatid = update.message.chat_id + info_message = update_info_message(chatid) message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(info_message), reply_markup=InlineKeyboardMarkup(update_first_buttons_message()), parse_mode='MarkdownV2', disable_web_page_preview=True) @decorators.GroupAuthorization @@ -354,7 +360,9 @@ async def handle_pdf(update, context): extracted_text_with_prompt = Document_extract(file_url) robot, role = get_robot() robot.add_to_conversation(extracted_text_with_prompt, role, str(update.effective_chat.id)) - if config.CLAUDE_API and "claude-3" in config.GPT_ENGINE: + chatid = update.message.chat_id + engine = get_ENGINE(chatid) + if config.CLAUDE_API and "claude-3" in engine: robot.add_to_conversation(claude3_doc_assistant_prompt, "assistant", str(update.effective_chat.id)) message = ( f"文档上传成功!\n\n" @@ -378,7 +386,7 @@ async def handle_photo(update, context): image_url = photo_file.file_path robot, role = get_robot() - message = get_image_message(image_url, []) + message = get_image_message(image_url, [], chatid) robot.add_to_conversation(message, role, str(chatid)) # if config.CLAUDE_API and "claude-3" in config.GPT_ENGINE: @@ -473,7 +481,7 @@ async def process_update(update): application.add_handler(CommandHandler("start", start)) application.add_handler(CommandHandler("pic", image, block = False)) - application.add_handler(CommandHandler("search", lambda update, context: command_bot(update, context, prompt="search: ", title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command="search"))) + application.add_handler(CommandHandler("search", lambda update, context: command_bot(update, context, prompt="search: ", has_command="search"))) application.add_handler(CallbackQueryHandler(button_press)) application.add_handler(CommandHandler("reset", reset_chat)) application.add_handler(CommandHandler("en2zh", lambda update, context: command_bot(update, context, "Simplified Chinese", robot=config.translate_bot))) @@ -481,8 +489,8 @@ async def process_update(update): application.add_handler(CommandHandler("info", info)) application.add_handler(InlineQueryHandler(inlinequery)) application.add_handler(MessageHandler(filters.Document.PDF | filters.Document.TXT | filters.Document.DOC, handle_pdf)) - application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command=False))) - application.add_handler(MessageHandler(filters.CAPTION & filters.PHOTO & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command=False))) + application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, has_command=False))) + application.add_handler(MessageHandler(filters.CAPTION & filters.PHOTO & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, has_command=False))) application.add_handler(MessageHandler(~filters.CAPTION & filters.PHOTO & ~filters.COMMAND, handle_photo)) application.add_handler(MessageHandler(filters.COMMAND, unknown)) application.add_error_handler(error) diff --git a/config.py b/config.py index 6515e92f..953e03cb 100644 --- a/config.py +++ b/config.py @@ -22,14 +22,6 @@ def replace_with_asterisk(string, start=10, end=45): API = os.environ.get('API', None) WEB_HOOK = os.environ.get('WEB_HOOK', None) -def update_info_message(): - return ( - f"**Model:** `{GPT_ENGINE}`\n\n" - f"**API_URL:** `{API_URL}`\n\n" - f"**API:** `{replace_with_asterisk(API)}`\n\n" - f"**WEB_HOOK:** `{WEB_HOOK}`\n\n" - ) - GROQ_API_KEY = os.environ.get('GROQ_API_KEY', None) GOOGLE_AI_API_KEY = os.environ.get('GOOGLE_AI_API_KEY', None) @@ -64,23 +56,77 @@ def update_language(): temperature = float(os.environ.get('temperature', '0.5')) CLAUDE_API = os.environ.get('claude_api_key', None) +class UserConfig: + def __init__(self, user_id: str = None, language="English", engine="gpt-4o", mode="global"): + self.user_id = user_id + self.language = language + self.engine = engine + self.users = { + "global": { + "language": self.language, + "engine": self.engine, + } + } + self.mode = mode + self.parameter_name_list = list(self.users["global"].keys()) + def user_init(self, user_id = None): + if user_id == None: + user_id = "global" + self.user_id = user_id + if self.user_id not in self.users.keys(): + self.users[self.user_id] = {"language": LANGUAGE, "engine": GPT_ENGINE} + + def get_config(self, user_id = None, parameter_name = None): + if parameter_name not in self.parameter_name_list: + raise ValueError("parameter_name is not in the parameter_name_list") + if self.mode == "global": + return self.users["global"][parameter_name] + if self.mode == "multiusers": + self.user_init(user_id) + return self.users[self.user_id][parameter_name] + + def set_config(self, user_id = None, parameter_name = None, value = None): + if parameter_name not in self.parameter_name_list: + raise ValueError("parameter_name is not in the parameter_name_list") + if self.mode == "global": + self.users["global"][parameter_name] = value + if self.mode == "multiusers": + self.user_init(user_id) + self.users[self.user_id][parameter_name] = value + +CHAT_MODE = os.environ.get('CHAT_MODE', "global") +Users = UserConfig(mode=CHAT_MODE) + +def get_ENGINE(user_id = None): + return Users.get_config(user_id, "engine") + +def update_info_message(user_id = None): + return ( + f"**Model:** `{get_ENGINE(user_id)}`\n\n" + f"**API_URL:** `{API_URL}`\n\n" + f"**API:** `{replace_with_asterisk(API)}`\n\n" + f"**WEB_HOOK:** `{WEB_HOOK}`\n\n" + ) + ChatGPTbot, translate_bot, dallbot, claudeBot, claude3Bot, groqBot, gemini_Bot = None, None, None, None, None, None, None -def update_ENGINE(data = None): - global GPT_ENGINE, ChatGPTbot, translate_bot, dallbot, claudeBot, claude3Bot, groqBot, gemini_Bot +def update_ENGINE(data = None, chat_id=None): + global Users, ChatGPTbot, translate_bot, dallbot, claudeBot, claude3Bot, groqBot, gemini_Bot if data: - GPT_ENGINE = data + Users.set_config(chat_id, "engine", data) + engine = Users.get_config(chat_id, "engine") if API: - ChatGPTbot = chatgpt(api_key=f"{API}", engine=GPT_ENGINE, system_prompt=systemprompt, temperature=temperature) - translate_bot = chatgpt(api_key=f"{API}", engine=GPT_ENGINE, system_prompt=systemprompt, temperature=temperature) + ChatGPTbot = chatgpt(api_key=f"{API}", engine=engine, system_prompt=systemprompt, temperature=temperature) + translate_bot = chatgpt(api_key=f"{API}", engine=engine, system_prompt=systemprompt, temperature=temperature) dallbot = dalle3(api_key=f"{API}") - if CLAUDE_API and "claude-2.1" in GPT_ENGINE: - claudeBot = claude(api_key=f"{CLAUDE_API}", engine=GPT_ENGINE, system_prompt=claude_systemprompt, temperature=temperature) - if CLAUDE_API and "claude-3" in GPT_ENGINE: - claude3Bot = claude3(api_key=f"{CLAUDE_API}", engine=GPT_ENGINE, system_prompt=claude_systemprompt, temperature=temperature) - if GROQ_API_KEY and ("mixtral" in GPT_ENGINE or "llama" in GPT_ENGINE): - groqBot = groq(api_key=f"{GROQ_API_KEY}", engine=GPT_ENGINE, system_prompt=systemprompt, temperature=temperature) - if GOOGLE_AI_API_KEY and "gemini" in GPT_ENGINE: - gemini_Bot = gemini(api_key=f"{GOOGLE_AI_API_KEY}", engine=GPT_ENGINE, system_prompt=systemprompt, temperature=temperature) + if CLAUDE_API and "claude-2.1" in engine: + claudeBot = claude(api_key=f"{CLAUDE_API}", engine=engine, system_prompt=claude_systemprompt, temperature=temperature) + if CLAUDE_API and "claude-3" in engine: + claude3Bot = claude3(api_key=f"{CLAUDE_API}", engine=engine, system_prompt=claude_systemprompt, temperature=temperature) + if GROQ_API_KEY and ("mixtral" in engine or "llama" in engine): + groqBot = groq(api_key=f"{GROQ_API_KEY}", engine=engine, system_prompt=systemprompt, temperature=temperature) + if GOOGLE_AI_API_KEY and "gemini" in engine: + gemini_Bot = gemini(api_key=f"{GOOGLE_AI_API_KEY}", engine=engine, system_prompt=systemprompt, temperature=temperature) + update_ENGINE() def reset_ENGINE(chat_id, message=None): @@ -99,17 +145,18 @@ def reset_ENGINE(chat_id, message=None): if GOOGLE_AI_API_KEY and gemini_Bot: gemini_Bot.reset(convo_id=str(chat_id), system_prompt=systemprompt) -def get_robot(): +def get_robot(chat_id = None): global ChatGPTbot, claudeBot, claude3Bot, groqBot, gemini_Bot - if CLAUDE_API and "claude-2.1" in GPT_ENGINE: + engine = Users.get_config(chat_id, "engine") + if CLAUDE_API and "claude-2.1" in engine: robot = claudeBot role = "Human" - elif CLAUDE_API and "claude-3" in GPT_ENGINE: + elif CLAUDE_API and "claude-3" in engine: robot = claude3Bot role = "user" - elif ("mixtral" in GPT_ENGINE or "llama" in GPT_ENGINE) and GROQ_API_KEY: + elif ("mixtral" in engine or "llama" in engine) and GROQ_API_KEY: robot = groqBot - elif GOOGLE_AI_API_KEY and "gemini" in GPT_ENGINE: + elif GOOGLE_AI_API_KEY and "gemini" in engine: robot = gemini_Bot role = "user" else: @@ -118,10 +165,11 @@ def get_robot(): return robot, role -def get_image_message(image_url, message): +def get_image_message(image_url, message, chatid = None): + engine = get_ENGINE(chatid) if image_url: base64_image = get_encode_image(image_url) - if "gpt-4" in GPT_ENGINE or (CLAUDE_API is None and "claude-3" in GPT_ENGINE): + if "gpt-4" in engine or (CLAUDE_API is None and "claude-3" in engine): message.append( { "type": "image_url", @@ -130,7 +178,7 @@ def get_image_message(image_url, message): } } ) - if CLAUDE_API and "claude-3" in GPT_ENGINE: + if CLAUDE_API and "claude-3" in engine: message.append( { "type": "image", @@ -153,36 +201,6 @@ def get_image_message(image_url, message): if GROUP_LIST: GROUP_LIST = [int(id) for id in GROUP_LIST.split(",")] -class UserConfig: - def __init__(self, user_id: str = "default", language="English", engine="gpt-4o"): - self.user_id = user_id - self.language = language - self.engine = engine - self.users = { - "default": { - "language": self.language, - "engine": self.engine, - } - } - def user_init(self, user_id): - if user_id not in self.users: - self.users[user_id] = {"language": LANGUAGE, "engine": GPT_ENGINE} - def get_language(self, user_id): - self.user_init(user_id) - return self.users[user_id]["language"] - def set_language(self, user_id, language): - self.user_init(user_id) - self.users[user_id]["language"] = language - - def get_engine(self, user_id): - self.user_init(user_id) - return self.users[user_id]["engine"] - def set_engine(self, user_id, engine): - self.user_init(user_id) - self.users[user_id]["engine"] = engine - -Users = UserConfig() - def delete_model_digit_tail(lst): for i in range(len(lst) - 1, -1, -1): if not lst[i].isdigit(): diff --git a/requirements.txt b/requirements.txt index e81c5e90..ad799a1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,6 @@ pytz python-dotenv md2tgmd==0.1.9 fake_useragent -ModelMerge==0.4.1 +ModelMerge==0.4.2 oauth2client==3.0.0 python-telegram-bot[webhooks,rate-limiter]==21.0.1 \ No newline at end of file