From 2a6c36de277145b0a0095b346054043f78f1a628 Mon Sep 17 00:00:00 2001 From: ciuzaak Date: Wed, 17 May 2023 22:15:00 +0800 Subject: [PATCH] feat: /retry to re-generate the answer refactor: Simplify error message return refactor: Optimized string processing in input_text and prompt refactor: regenerate --- README.md | 1 + bot.py | 77 +++++++++++++++++++++---------------------- utils/bard_utils.py | 18 ++++++++++ utils/claude_utils.py | 3 ++ 4 files changed, 60 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index d9d2896..42c4c9f 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,7 @@ If you only have access to one of the models, you can still continue to use this 3. Input your other questions and send 4. Send `/seg` again 5. Bot will respond and you can continue the conversation +- `/retry`: regenerate the answer. Use `/retry TEXT` to modify your last input. #### Others diff --git a/bot.py b/bot.py index ffba471..a36b7be 100644 --- a/bot.py +++ b/bot.py @@ -1,4 +1,4 @@ -from re import compile, sub +from re import sub from urllib.parse import quote from telegram import (BotCommand, InlineKeyboardButton, InlineKeyboardMarkup, @@ -25,6 +25,7 @@ async def reset_chat(update: Update, context: ContextTypes.DEFAULT_TYPE): mode, session = get_session(update, context) session.reset() context.chat_data[mode].pop('last_msg_id', None) + context.chat_data[mode].pop('last_input', None) context.chat_data[mode].pop('seg_message', None) context.chat_data[mode].pop('drafts', None) await update.message.reply_text('๐Ÿงน Chat history has been reset.') @@ -66,41 +67,47 @@ async def bard_response(update: Update, context: ContextTypes.DEFAULT_TYPE): async def recv_msg(update: Update, context: ContextTypes.DEFAULT_TYPE): - if update.message.chat.type == 'private': - input_text = update.message.text - else: + input_text = update.message.text + if update.message.chat.type != 'private': if update.message.reply_to_message and update.message.reply_to_message.from_user.username == context.bot.username: - input_text = update.message.text - elif update.message.entities is not None and compile(f'@{context.bot.username}').search(update.message.text): - input_text = update.message.text.replace( - f'@{context.bot.username}', '').strip() + pass + elif update.message.entities is not None and input_text.startswith(f'@{context.bot.username}'): + input_text = input_text.lstrip(f'@{context.bot.username}').lstrip() else: return - mode, session = get_session(update, context) + # handle long message (for claude 100k model) seg_message = context.chat_data[mode].get('seg_message') if seg_message is None: if input_text.startswith('/seg'): - input_text = '/seg'.join(input_text.split('/seg')[1:]).strip() + input_text = input_text.lstrip('/seg').lstrip() if input_text.endswith('/seg'): - input_text = '/seg'.join(input_text.split('/seg')[:-1]).strip() + input_text = input_text.rstrip('/seg').rstrip() else: context.chat_data[mode]['seg_message'] = input_text return else: if input_text.endswith('/seg'): - input_text = '/seg'.join(input_text.split('/seg')[:-1]).strip() - input_text = f'{seg_message}\n\n{input_text}'.strip() + input_text = f"{seg_message}\n\n{input_text.rstrip('/seg')}".strip() context.chat_data[mode].pop('seg_message', None) else: context.chat_data[mode]['seg_message'] = f'{seg_message}\n\n{input_text}' return + # regenerate the answer + if input_text.startswith('/retry'): + last_input = context.chat_data[mode].get('last_input') + if last_input is None: + return await update.message.reply_text('โŒ Empty conversation.') + session.revert() + input_text = input_text.lstrip('/retry').lstrip() + input_text = input_text or last_input + if input_text == '': - await update.message.reply_text('โŒ Empty message.') - return + return await update.message.reply_text('โŒ Empty message.') message = await update.message.reply_text('.') + context.chat_data[mode]['last_input'] = input_text context.chat_data[mode]['last_msg_id'] = message.message_id if mode == 'Claude': @@ -128,7 +135,7 @@ async def recv_msg(update: Update, context: ContextTypes.DEFAULT_TYPE): await message.edit_text(f'โŒ Error orrurred: {e}. /reset') else: # Bard - response = session.client.ask(input_text) + response = session.send_message(input_text) # get source links sources = '' if response['factualityQueries']: @@ -179,8 +186,7 @@ async def show_settings(update: Update, context: ContextTypes.DEFAULT_TYPE): async def change_mode(update: Update, context: ContextTypes.DEFAULT_TYPE): if single_mode: - await update.message.reply_text(f'โŒ You cannot access the other mode.') - return + return await update.message.reply_text(f'โŒ You cannot access the other mode.') mode, _ = get_session(update, context) final_mode, emoji = ('Bard', '๐ŸŸ ') if mode == 'Claude' else ('Claude', '๐ŸŸฃ') @@ -200,16 +206,13 @@ async def change_model(update: Update, context: ContextTypes.DEFAULT_TYPE): mode, session = get_session(update, context) if mode == 'Bard': - await update.message.reply_text('โŒ Invalid option for Google Bard.') - return - + return await update.message.reply_text('โŒ Invalid option for Google Bard.') if len(context.args) != 1: - await update.message.reply_text('โŒ Please provide a model name.') - return + return await update.message.reply_text('โŒ Please provide a model name.') + model = context.args[0].strip() if not session.change_model(model): - await update.message.reply_text('โŒ Invalid model name.') - return + return await update.message.reply_text('โŒ Invalid model name.') await update.message.reply_text(f'๐Ÿค– Model has been switched to {model}.', parse_mode=ParseMode.HTML) @@ -218,16 +221,13 @@ async def change_temperature(update: Update, context: ContextTypes.DEFAULT_TYPE) mode, session = get_session(update, context) if mode == 'Bard': - await update.message.reply_text('โŒ Invalid option for Google Bard.') - return - + return await update.message.reply_text('โŒ Invalid option for Google Bard.') if len(context.args) != 1: - await update.message.reply_text('โŒ Please provide a temperature value.') - return + return await update.message.reply_text('โŒ Please provide a temperature value.') + temperature = context.args[0].strip() if not session.change_temperature(temperature): - await update.message.reply_text('โŒ Invalid temperature value.') - return + return await update.message.reply_text('โŒ Invalid temperature value.') await update.message.reply_text(f'๐ŸŒก๏ธ Temperature has been set to {temperature}.', parse_mode=ParseMode.HTML) @@ -236,16 +236,13 @@ async def change_cutoff(update: Update, context: ContextTypes.DEFAULT_TYPE): mode, session = get_session(update, context) if mode == 'Bard': - await update.message.reply_text('โŒ Invalid option for Google Bard.') - return - + return await update.message.reply_text('โŒ Invalid option for Google Bard.') if len(context.args) != 1: - await update.message.reply_text('โŒ Please provide a cutoff value.') - return + return await update.message.reply_text('โŒ Please provide a cutoff value.') + cutoff = context.args[0].strip() if not session.change_cutoff(cutoff): - await update.message.reply_text('โŒ Invalid cutoff value.') - return + return await update.message.reply_text('โŒ Invalid cutoff value.') await update.message.reply_text(f'โœ‚๏ธ Cutoff has been set to {cutoff}.', parse_mode=ParseMode.HTML) @@ -257,6 +254,7 @@ async def start_bot(update: Update, context: ContextTypes.DEFAULT_TYPE): 'Commands:', 'โ€ข /id to get your chat identifier', 'โ€ข /reset to reset the chat history', + 'โ€ข /retry to regenerate the answer', 'โ€ข /seg to send message in segments', 'โ€ข /mode to switch between Claude & Bard', 'โ€ข /settings to show Claude & Bard settings', @@ -278,6 +276,7 @@ async def error_handler(update: Update, context: ContextTypes.DEFAULT_TYPE): async def post_init(application: Application): await application.bot.set_my_commands([ BotCommand('/reset', 'Reset the chat history'), + BotCommand('/retry', 'Regenerate the answer'), BotCommand('/seg', 'Send message in segments'), BotCommand('/mode', 'Switch between Claude & Bard'), BotCommand('/settings', 'Show Claude & Bard settings'), diff --git a/utils/bard_utils.py b/utils/bard_utils.py index 5bbbaf4..631a640 100644 --- a/utils/bard_utils.py +++ b/utils/bard_utils.py @@ -6,8 +6,26 @@ class Bard: def __init__(self): self.client = Chatbot(bard_api) + self.prev_conversation_id = '' + self.prev_response_id = '' + self.prev_choice_id = '' def reset(self): self.client.conversation_id = '' self.client.response_id = '' self.client.choice_id = '' + self.prev_conversation_id = '' + self.prev_response_id = '' + self.prev_choice_id = '' + + def revert(self): + self.client.conversation_id = self.prev_conversation_id + self.client.response_id = self.prev_response_id + self.client.choice_id = self.prev_choice_id + + def send_message(self, message): + self.prev_conversation_id = self.client.conversation_id + self.prev_response_id = self.client.response_id + self.prev_choice_id = self.client.choice_id + response = self.client.ask(message) + return response diff --git a/utils/claude_utils.py b/utils/claude_utils.py index 4e7bd95..4e15459 100644 --- a/utils/claude_utils.py +++ b/utils/claude_utils.py @@ -14,6 +14,9 @@ def __init__(self): def reset(self): self.prompt = '' + def revert(self): + self.prompt = self.prompt[:self.prompt.rfind(HUMAN_PROMPT)] + def change_model(self, model): valid_models = { 'claude-v1',