diff --git a/bot.py b/bot.py index d4d3c950..fe574c4f 100644 --- a/bot.py +++ b/bot.py @@ -31,19 +31,20 @@ print("nick:", botNick) translator_prompt = "You are a translation engine, you can only translate text and cannot interpret it, and do not explain. Translate the text to {}, please do not explain any sentences, just translate or leave them as they are. this is the content you need to translate: " +@decorators.GroupAuthorization @decorators.Authorization async def command_bot(update, context, language=None, prompt=translator_prompt, title="", robot=None, has_command=True): - if not config.ALLOWPRIVATECHAT and update.message.chat.type == "private": - return if has_command == False or len(context.args) > 0: if update.edited_message: message = update.edited_message.text if config.NICK is None else update.edited_message.text[botNicKLength:].strip() if update.edited_message.text[:botNicKLength].lower() == botNick else None rawtext = update.edited_message.text - chatid = update.effective_chat.id + chatid = update.edited_message.chat_id + messageid = update.edited_message.message_id else: message = update.message.text if config.NICK is None else update.message.text[botNicKLength:].strip() if update.message.text[:botNicKLength].lower() == botNick else None rawtext = update.message.text chatid = update.message.chat_id + messageid = update.message.message_id print("\033[32m", update.effective_user.username, update.effective_user.id, rawtext, "\033[0m") if has_command: message = ' '.join(context.args) @@ -54,7 +55,7 @@ async def command_bot(update, context, language=None, prompt=translator_prompt, if "claude" in config.GPT_ENGINE and config.ClaudeAPI: robot = config.claudeBot await context.bot.send_chat_action(chat_id=chatid, action=ChatAction.TYPING) - await getChatGPT(update, context, title, robot, message, chatid) + await getChatGPT(update, context, title, robot, message, chatid, messageid) else: message = await context.bot.send_message( chat_id=chatid, @@ -63,6 +64,7 @@ async def command_bot(update, context, language=None, prompt=translator_prompt, reply_to_message_id=update.message.message_id, ) +@decorators.GroupAuthorization @decorators.Authorization async def reset_chat(update, context): if config.API: @@ -74,15 +76,11 @@ async def reset_chat(update, context): text="重置成功!", ) -async def getChatGPT(update, context, title, robot, message, chatid=None): +async def getChatGPT(update, context, title, robot, message, chatid, messageid): result = title text = message modifytime = 0 lastresult = '' - if update.edited_message: - messageid = update.edited_message.message_id - else: - messageid = update.message.message_id message = await context.bot.send_message( chat_id=chatid, @@ -97,7 +95,7 @@ async def getChatGPT(update, context, title, robot, message, chatid=None): get_answer = gpt4free.get_response try: - for data in get_answer(text, convo_id=str(update.effective_user.id), pass_history=config.PASS_HISTORY): + for data in get_answer(text, convo_id=str(chatid), pass_history=config.PASS_HISTORY): result = result + data tmpresult = result modifytime = modifytime + 1 @@ -117,7 +115,7 @@ async def getChatGPT(update, context, title, robot, message, chatid=None): traceback.print_exc() print('\033[0m') if config.API: - robot.reset(convo_id=str(update.effective_user.id), system_prompt=config.systemprompt) + robot.reset(convo_id=str(chatid), system_prompt=config.systemprompt) if "You exceeded your current quota, please check your plan and billing details." in str(e): print("OpenAI api 已过期!") await context.bot.delete_message(chat_id=chatid, message_id=messageid) @@ -130,6 +128,8 @@ async def getChatGPT(update, context, title, robot, message, chatid=None): result = re.sub(r",", ',', result) await context.bot.edit_message_text(chat_id=chatid, message_id=messageid, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True) +@decorators.GroupAuthorization +@decorators.Authorization async def search(update, context, title, robot): message = update.message.text if config.NICK is None else update.message.text[botNicKLength:].strip() if update.message.text[:botNicKLength].lower() == botNick else None print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m") @@ -192,6 +192,8 @@ async def search(update, context, title, robot): result = re.sub(r",", ',', result) await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True) +@decorators.GroupAuthorization +@decorators.Authorization async def image(update, context): print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m") if (len(context.args) == 0): @@ -307,6 +309,9 @@ def replace_with_asterisk(string, start=10, end=45): return string[:start] + '*' * (end - start) + string[end:] banner = "👇下面可以随时更改默认 gpt 模型:" +@decorators.AdminAuthorization +@decorators.GroupAuthorization +@decorators.Authorization async def button_press(update, context): """Function to handle the button press""" info_message = ( @@ -487,6 +492,7 @@ async def button_press(update, context): ) @decorators.AdminAuthorization +@decorators.GroupAuthorization @decorators.Authorization async def info(update, context): info_message = ( @@ -499,9 +505,7 @@ async def info(update, context): ) message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(info_message), reply_markup=InlineKeyboardMarkup(first_buttons), parse_mode='MarkdownV2', disable_web_page_preview=True) - # messageid = message.message_id - # await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id) - +@decorators.GroupAuthorization @decorators.Authorization async def handle_pdf(update, context): # 获取接收到的文件 @@ -524,6 +528,7 @@ async def handle_pdf(update, context): ) await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True) +@decorators.GroupAuthorization @decorators.Authorization async def qa(update, context): if (len(context.args) != 2): @@ -567,6 +572,7 @@ async def error(update, context): logger.warning('Update "%s" caused error "%s"', update, context.error) await update.message.reply_text(escape("出错啦!请重试。"), parse_mode='MarkdownV2', disable_web_page_preview=True) +@decorators.GroupAuthorization @decorators.Authorization async def unknown(update, context): # 当用户输入未知命令时,返回文本 await context.bot.send_message(chat_id=update.effective_chat.id, text="Sorry, I didn't understand that command.") @@ -578,7 +584,7 @@ async def post_init(application: Application) -> None: BotCommand('search', 'search Google or duckduckgo'), BotCommand('en2zh', 'translate to Chinese'), BotCommand('zh2en', 'translate to English'), - BotCommand('qa', 'Document Q&A with Embedding Database Search'), + # BotCommand('qa', 'Document Q&A with Embedding Database Search'), BotCommand('start', 'Start the bot'), BotCommand('reset', 'Reset the bot'), ]) diff --git a/config.py b/config.py index fb67e629..f43994a1 100644 --- a/config.py +++ b/config.py @@ -15,7 +15,7 @@ API_URL = os.environ.get('API_URL', 'https://api.openai.com/v1/chat/completions') PDF_EMBEDDING = (os.environ.get('PDF_EMBEDDING', "True") == "False") == False LANGUAGE = os.environ.get('LANGUAGE', 'Simplified Chinese') -ALLOWPRIVATECHAT = (os.environ.get('ALLOWPRIVATECHAT', "True") == "False") == False + from datetime import datetime current_date = datetime.now() @@ -41,6 +41,9 @@ ADMIN_LIST = os.environ.get('ADMIN_LIST', None) if ADMIN_LIST: ADMIN_LIST = [int(id) for id in ADMIN_LIST.split(",")] +GROUP_LIST = os.environ.get('GROUP_LIST', None) +if GROUP_LIST: + GROUP_LIST = [int(id) for id in GROUP_LIST.split(",")] USE_G4F = False diff --git a/utils/decorators.py b/utils/decorators.py index 64265eff..d651b6f6 100644 --- a/utils/decorators.py +++ b/utils/decorators.py @@ -5,7 +5,23 @@ def Authorization(func): async def wrapper(*args, **kwargs): if config.whitelist == None: return await func(*args, **kwargs) - if (args[0].effective_chat.id not in config.whitelist): + if (args[0].effective_user.id not in config.whitelist): + message = ( + f"`Hi, {args[0].effective_user.username}!`\n\n" + f"id: `{args[0].effective_user.id}`\n\n" + f"无权访问!\n\n" + ) + await args[1].bot.send_message(chat_id=args[0].effective_user.id, text=message, parse_mode='MarkdownV2') + return + return await func(*args, **kwargs) + return wrapper + +# 判断是否在群聊白名单 +def GroupAuthorization(func): + async def wrapper(*args, **kwargs): + if config.GROUP_LIST == None: + return await func(*args, **kwargs) + if (args[0].effective_chat.id not in config.GROUP_LIST): message = ( f"`Hi, {args[0].effective_user.username}!`\n\n" f"id: `{args[0].effective_user.id}`\n\n" @@ -21,13 +37,13 @@ def AdminAuthorization(func): async def wrapper(*args, **kwargs): if config.ADMIN_LIST == None: return await func(*args, **kwargs) - if (args[0].effective_chat.id not in config.ADMIN_LIST): + if (args[0].effective_user.id not in config.ADMIN_LIST): message = ( f"`Hi, {args[0].effective_user.username}!`\n\n" f"id: `{args[0].effective_user.id}`\n\n" f"无权访问!\n\n" ) - await args[1].bot.send_message(chat_id=args[0].effective_chat.id, text=message, parse_mode='MarkdownV2') + await args[1].bot.send_message(chat_id=args[0].effective_user.id, text=message, parse_mode='MarkdownV2') return return await func(*args, **kwargs) return wrapper \ No newline at end of file