Skip to content

Commit

Permalink
Add group whitelist environment variables
Browse files Browse the repository at this point in the history
2. Remove the option to allow private messaging, as enabling group whitelist will also disable private messaging.

3. Optimize code logic

4. Fix user whitelist authentication bug

5. Fix the authentication bug that allowed users to use the previous info command panel.
  • Loading branch information
yym68686 committed Dec 7, 2023
1 parent 4ada43e commit 6598fed
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
36 changes: 21 additions & 15 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,20 @@
print("nick:", botNick)
translator_prompt = "You are a translation engine, you can only translate text and cannot interpret it, and do not explain. Translate the text to {}, please do not explain any sentences, just translate or leave them as they are. this is the content you need to translate: "

@decorators.GroupAuthorization
@decorators.Authorization
async def command_bot(update, context, language=None, prompt=translator_prompt, title="", robot=None, has_command=True):
if not config.ALLOWPRIVATECHAT and update.message.chat.type == "private":
return
if has_command == False or len(context.args) > 0:
if update.edited_message:
message = update.edited_message.text if config.NICK is None else update.edited_message.text[botNicKLength:].strip() if update.edited_message.text[:botNicKLength].lower() == botNick else None
rawtext = update.edited_message.text
chatid = update.effective_chat.id
chatid = update.edited_message.chat_id
messageid = update.edited_message.message_id
else:
message = update.message.text if config.NICK is None else update.message.text[botNicKLength:].strip() if update.message.text[:botNicKLength].lower() == botNick else None
rawtext = update.message.text
chatid = update.message.chat_id
messageid = update.message.message_id
print("\033[32m", update.effective_user.username, update.effective_user.id, rawtext, "\033[0m")
if has_command:
message = ' '.join(context.args)
Expand All @@ -54,7 +55,7 @@ async def command_bot(update, context, language=None, prompt=translator_prompt,
if "claude" in config.GPT_ENGINE and config.ClaudeAPI:
robot = config.claudeBot
await context.bot.send_chat_action(chat_id=chatid, action=ChatAction.TYPING)
await getChatGPT(update, context, title, robot, message, chatid)
await getChatGPT(update, context, title, robot, message, chatid, messageid)
else:
message = await context.bot.send_message(
chat_id=chatid,
Expand All @@ -63,6 +64,7 @@ async def command_bot(update, context, language=None, prompt=translator_prompt,
reply_to_message_id=update.message.message_id,
)

@decorators.GroupAuthorization
@decorators.Authorization
async def reset_chat(update, context):
if config.API:
Expand All @@ -74,15 +76,11 @@ async def reset_chat(update, context):
text="重置成功!",
)

async def getChatGPT(update, context, title, robot, message, chatid=None):
async def getChatGPT(update, context, title, robot, message, chatid, messageid):
result = title
text = message
modifytime = 0
lastresult = ''
if update.edited_message:
messageid = update.edited_message.message_id
else:
messageid = update.message.message_id

message = await context.bot.send_message(
chat_id=chatid,
Expand All @@ -97,7 +95,7 @@ async def getChatGPT(update, context, title, robot, message, chatid=None):
get_answer = gpt4free.get_response

try:
for data in get_answer(text, convo_id=str(update.effective_user.id), pass_history=config.PASS_HISTORY):
for data in get_answer(text, convo_id=str(chatid), pass_history=config.PASS_HISTORY):
result = result + data
tmpresult = result
modifytime = modifytime + 1
Expand All @@ -117,7 +115,7 @@ async def getChatGPT(update, context, title, robot, message, chatid=None):
traceback.print_exc()
print('\033[0m')
if config.API:
robot.reset(convo_id=str(update.effective_user.id), system_prompt=config.systemprompt)
robot.reset(convo_id=str(chatid), system_prompt=config.systemprompt)
if "You exceeded your current quota, please check your plan and billing details." in str(e):
print("OpenAI api 已过期!")
await context.bot.delete_message(chat_id=chatid, message_id=messageid)
Expand All @@ -130,6 +128,8 @@ async def getChatGPT(update, context, title, robot, message, chatid=None):
result = re.sub(r",", ',', result)
await context.bot.edit_message_text(chat_id=chatid, message_id=messageid, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)

@decorators.GroupAuthorization
@decorators.Authorization
async def search(update, context, title, robot):
message = update.message.text if config.NICK is None else update.message.text[botNicKLength:].strip() if update.message.text[:botNicKLength].lower() == botNick else None
print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
Expand Down Expand Up @@ -192,6 +192,8 @@ async def search(update, context, title, robot):
result = re.sub(r",", ',', result)
await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)

@decorators.GroupAuthorization
@decorators.Authorization
async def image(update, context):
print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
if (len(context.args) == 0):
Expand Down Expand Up @@ -307,6 +309,9 @@ def replace_with_asterisk(string, start=10, end=45):
return string[:start] + '*' * (end - start) + string[end:]

banner = "👇下面可以随时更改默认 gpt 模型:"
@decorators.AdminAuthorization
@decorators.GroupAuthorization
@decorators.Authorization
async def button_press(update, context):
"""Function to handle the button press"""
info_message = (
Expand Down Expand Up @@ -487,6 +492,7 @@ async def button_press(update, context):
)

@decorators.AdminAuthorization
@decorators.GroupAuthorization
@decorators.Authorization
async def info(update, context):
info_message = (
Expand All @@ -499,9 +505,7 @@ async def info(update, context):
)
message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(info_message), reply_markup=InlineKeyboardMarkup(first_buttons), parse_mode='MarkdownV2', disable_web_page_preview=True)

# messageid = message.message_id
# await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id)

@decorators.GroupAuthorization
@decorators.Authorization
async def handle_pdf(update, context):
# 获取接收到的文件
Expand All @@ -524,6 +528,7 @@ async def handle_pdf(update, context):
)
await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)

@decorators.GroupAuthorization
@decorators.Authorization
async def qa(update, context):
if (len(context.args) != 2):
Expand Down Expand Up @@ -567,6 +572,7 @@ async def error(update, context):
logger.warning('Update "%s" caused error "%s"', update, context.error)
await update.message.reply_text(escape("出错啦!请重试。"), parse_mode='MarkdownV2', disable_web_page_preview=True)

@decorators.GroupAuthorization
@decorators.Authorization
async def unknown(update, context): # 当用户输入未知命令时,返回文本
await context.bot.send_message(chat_id=update.effective_chat.id, text="Sorry, I didn't understand that command.")
Expand All @@ -578,7 +584,7 @@ async def post_init(application: Application) -> None:
BotCommand('search', 'search Google or duckduckgo'),
BotCommand('en2zh', 'translate to Chinese'),
BotCommand('zh2en', 'translate to English'),
BotCommand('qa', 'Document Q&A with Embedding Database Search'),
# BotCommand('qa', 'Document Q&A with Embedding Database Search'),
BotCommand('start', 'Start the bot'),
BotCommand('reset', 'Reset the bot'),
])
Expand Down
5 changes: 4 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
API_URL = os.environ.get('API_URL', 'https://api.openai.com/v1/chat/completions')
PDF_EMBEDDING = (os.environ.get('PDF_EMBEDDING', "True") == "False") == False
LANGUAGE = os.environ.get('LANGUAGE', 'Simplified Chinese')
ALLOWPRIVATECHAT = (os.environ.get('ALLOWPRIVATECHAT', "True") == "False") == False


from datetime import datetime
current_date = datetime.now()
Expand All @@ -41,6 +41,9 @@
ADMIN_LIST = os.environ.get('ADMIN_LIST', None)
if ADMIN_LIST:
ADMIN_LIST = [int(id) for id in ADMIN_LIST.split(",")]
GROUP_LIST = os.environ.get('GROUP_LIST', None)
if GROUP_LIST:
GROUP_LIST = [int(id) for id in GROUP_LIST.split(",")]

USE_G4F = False

Expand Down
22 changes: 19 additions & 3 deletions utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,23 @@ def Authorization(func):
async def wrapper(*args, **kwargs):
if config.whitelist == None:
return await func(*args, **kwargs)
if (args[0].effective_chat.id not in config.whitelist):
if (args[0].effective_user.id not in config.whitelist):
message = (
f"`Hi, {args[0].effective_user.username}!`\n\n"
f"id: `{args[0].effective_user.id}`\n\n"
f"无权访问!\n\n"
)
await args[1].bot.send_message(chat_id=args[0].effective_user.id, text=message, parse_mode='MarkdownV2')
return
return await func(*args, **kwargs)
return wrapper

# 判断是否在群聊白名单
def GroupAuthorization(func):
async def wrapper(*args, **kwargs):
if config.GROUP_LIST == None:
return await func(*args, **kwargs)
if (args[0].effective_chat.id not in config.GROUP_LIST):
message = (
f"`Hi, {args[0].effective_user.username}!`\n\n"
f"id: `{args[0].effective_user.id}`\n\n"
Expand All @@ -21,13 +37,13 @@ def AdminAuthorization(func):
async def wrapper(*args, **kwargs):
if config.ADMIN_LIST == None:
return await func(*args, **kwargs)
if (args[0].effective_chat.id not in config.ADMIN_LIST):
if (args[0].effective_user.id not in config.ADMIN_LIST):
message = (
f"`Hi, {args[0].effective_user.username}!`\n\n"
f"id: `{args[0].effective_user.id}`\n\n"
f"无权访问!\n\n"
)
await args[1].bot.send_message(chat_id=args[0].effective_chat.id, text=message, parse_mode='MarkdownV2')
await args[1].bot.send_message(chat_id=args[0].effective_user.id, text=message, parse_mode='MarkdownV2')
return
return await func(*args, **kwargs)
return wrapper

0 comments on commit 6598fed

Please sign in to comment.