diff --git a/README.md b/README.md
index d9d2896..42c4c9f 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,7 @@ If you only have access to one of the models, you can still continue to use this
3. Input your other questions and send
4. Send `/seg` again
5. Bot will respond and you can continue the conversation
+- `/retry`: regenerate the answer. Use `/retry TEXT` to modify your last input.
#### Others
diff --git a/bot.py b/bot.py
index ffba471..a36b7be 100644
--- a/bot.py
+++ b/bot.py
@@ -1,4 +1,4 @@
-from re import compile, sub
+from re import sub
from urllib.parse import quote
from telegram import (BotCommand, InlineKeyboardButton, InlineKeyboardMarkup,
@@ -25,6 +25,7 @@ async def reset_chat(update: Update, context: ContextTypes.DEFAULT_TYPE):
mode, session = get_session(update, context)
session.reset()
context.chat_data[mode].pop('last_msg_id', None)
+ context.chat_data[mode].pop('last_input', None)
context.chat_data[mode].pop('seg_message', None)
context.chat_data[mode].pop('drafts', None)
await update.message.reply_text('๐งน Chat history has been reset.')
@@ -66,41 +67,47 @@ async def bard_response(update: Update, context: ContextTypes.DEFAULT_TYPE):
async def recv_msg(update: Update, context: ContextTypes.DEFAULT_TYPE):
- if update.message.chat.type == 'private':
- input_text = update.message.text
- else:
+ input_text = update.message.text
+ if update.message.chat.type != 'private':
if update.message.reply_to_message and update.message.reply_to_message.from_user.username == context.bot.username:
- input_text = update.message.text
- elif update.message.entities is not None and compile(f'@{context.bot.username}').search(update.message.text):
- input_text = update.message.text.replace(
- f'@{context.bot.username}', '').strip()
+ pass
+ elif update.message.entities is not None and input_text.startswith(f'@{context.bot.username}'):
+ input_text = input_text.lstrip(f'@{context.bot.username}').lstrip()
else:
return
-
mode, session = get_session(update, context)
+
# handle long message (for claude 100k model)
seg_message = context.chat_data[mode].get('seg_message')
if seg_message is None:
if input_text.startswith('/seg'):
- input_text = '/seg'.join(input_text.split('/seg')[1:]).strip()
+ input_text = input_text.lstrip('/seg').lstrip()
if input_text.endswith('/seg'):
- input_text = '/seg'.join(input_text.split('/seg')[:-1]).strip()
+ input_text = input_text.rstrip('/seg').rstrip()
else:
context.chat_data[mode]['seg_message'] = input_text
return
else:
if input_text.endswith('/seg'):
- input_text = '/seg'.join(input_text.split('/seg')[:-1]).strip()
- input_text = f'{seg_message}\n\n{input_text}'.strip()
+ input_text = f"{seg_message}\n\n{input_text.rstrip('/seg')}".strip()
context.chat_data[mode].pop('seg_message', None)
else:
context.chat_data[mode]['seg_message'] = f'{seg_message}\n\n{input_text}'
return
+ # regenerate the answer
+ if input_text.startswith('/retry'):
+ last_input = context.chat_data[mode].get('last_input')
+ if last_input is None:
+ return await update.message.reply_text('โ Empty conversation.')
+ session.revert()
+ input_text = input_text.lstrip('/retry').lstrip()
+ input_text = input_text or last_input
+
if input_text == '':
- await update.message.reply_text('โ Empty message.')
- return
+ return await update.message.reply_text('โ Empty message.')
message = await update.message.reply_text('.')
+ context.chat_data[mode]['last_input'] = input_text
context.chat_data[mode]['last_msg_id'] = message.message_id
if mode == 'Claude':
@@ -128,7 +135,7 @@ async def recv_msg(update: Update, context: ContextTypes.DEFAULT_TYPE):
await message.edit_text(f'โ Error orrurred: {e}. /reset')
else: # Bard
- response = session.client.ask(input_text)
+ response = session.send_message(input_text)
# get source links
sources = ''
if response['factualityQueries']:
@@ -179,8 +186,7 @@ async def show_settings(update: Update, context: ContextTypes.DEFAULT_TYPE):
async def change_mode(update: Update, context: ContextTypes.DEFAULT_TYPE):
if single_mode:
- await update.message.reply_text(f'โ You cannot access the other mode.')
- return
+ return await update.message.reply_text(f'โ You cannot access the other mode.')
mode, _ = get_session(update, context)
final_mode, emoji = ('Bard', '๐ ') if mode == 'Claude' else ('Claude', '๐ฃ')
@@ -200,16 +206,13 @@ async def change_model(update: Update, context: ContextTypes.DEFAULT_TYPE):
mode, session = get_session(update, context)
if mode == 'Bard':
- await update.message.reply_text('โ Invalid option for Google Bard.')
- return
-
+ return await update.message.reply_text('โ Invalid option for Google Bard.')
if len(context.args) != 1:
- await update.message.reply_text('โ Please provide a model name.')
- return
+ return await update.message.reply_text('โ Please provide a model name.')
+
model = context.args[0].strip()
if not session.change_model(model):
- await update.message.reply_text('โ Invalid model name.')
- return
+ return await update.message.reply_text('โ Invalid model name.')
await update.message.reply_text(f'๐ค Model has been switched to {model}.',
parse_mode=ParseMode.HTML)
@@ -218,16 +221,13 @@ async def change_temperature(update: Update, context: ContextTypes.DEFAULT_TYPE)
mode, session = get_session(update, context)
if mode == 'Bard':
- await update.message.reply_text('โ Invalid option for Google Bard.')
- return
-
+ return await update.message.reply_text('โ Invalid option for Google Bard.')
if len(context.args) != 1:
- await update.message.reply_text('โ Please provide a temperature value.')
- return
+ return await update.message.reply_text('โ Please provide a temperature value.')
+
temperature = context.args[0].strip()
if not session.change_temperature(temperature):
- await update.message.reply_text('โ Invalid temperature value.')
- return
+ return await update.message.reply_text('โ Invalid temperature value.')
await update.message.reply_text(f'๐ก๏ธ Temperature has been set to {temperature}.',
parse_mode=ParseMode.HTML)
@@ -236,16 +236,13 @@ async def change_cutoff(update: Update, context: ContextTypes.DEFAULT_TYPE):
mode, session = get_session(update, context)
if mode == 'Bard':
- await update.message.reply_text('โ Invalid option for Google Bard.')
- return
-
+ return await update.message.reply_text('โ Invalid option for Google Bard.')
if len(context.args) != 1:
- await update.message.reply_text('โ Please provide a cutoff value.')
- return
+ return await update.message.reply_text('โ Please provide a cutoff value.')
+
cutoff = context.args[0].strip()
if not session.change_cutoff(cutoff):
- await update.message.reply_text('โ Invalid cutoff value.')
- return
+ return await update.message.reply_text('โ Invalid cutoff value.')
await update.message.reply_text(f'โ๏ธ Cutoff has been set to {cutoff}.',
parse_mode=ParseMode.HTML)
@@ -257,6 +254,7 @@ async def start_bot(update: Update, context: ContextTypes.DEFAULT_TYPE):
'Commands:',
'โข /id to get your chat identifier',
'โข /reset to reset the chat history',
+ 'โข /retry to regenerate the answer',
'โข /seg to send message in segments',
'โข /mode to switch between Claude & Bard',
'โข /settings to show Claude & Bard settings',
@@ -278,6 +276,7 @@ async def error_handler(update: Update, context: ContextTypes.DEFAULT_TYPE):
async def post_init(application: Application):
await application.bot.set_my_commands([
BotCommand('/reset', 'Reset the chat history'),
+ BotCommand('/retry', 'Regenerate the answer'),
BotCommand('/seg', 'Send message in segments'),
BotCommand('/mode', 'Switch between Claude & Bard'),
BotCommand('/settings', 'Show Claude & Bard settings'),
diff --git a/utils/bard_utils.py b/utils/bard_utils.py
index 5bbbaf4..631a640 100644
--- a/utils/bard_utils.py
+++ b/utils/bard_utils.py
@@ -6,8 +6,26 @@
class Bard:
def __init__(self):
self.client = Chatbot(bard_api)
+ self.prev_conversation_id = ''
+ self.prev_response_id = ''
+ self.prev_choice_id = ''
def reset(self):
self.client.conversation_id = ''
self.client.response_id = ''
self.client.choice_id = ''
+ self.prev_conversation_id = ''
+ self.prev_response_id = ''
+ self.prev_choice_id = ''
+
+ def revert(self):
+ self.client.conversation_id = self.prev_conversation_id
+ self.client.response_id = self.prev_response_id
+ self.client.choice_id = self.prev_choice_id
+
+ def send_message(self, message):
+ self.prev_conversation_id = self.client.conversation_id
+ self.prev_response_id = self.client.response_id
+ self.prev_choice_id = self.client.choice_id
+ response = self.client.ask(message)
+ return response
diff --git a/utils/claude_utils.py b/utils/claude_utils.py
index 4e7bd95..4e15459 100644
--- a/utils/claude_utils.py
+++ b/utils/claude_utils.py
@@ -14,6 +14,9 @@ def __init__(self):
def reset(self):
self.prompt = ''
+ def revert(self):
+ self.prompt = self.prompt[:self.prompt.rfind(HUMAN_PROMPT)]
+
def change_model(self, model):
valid_models = {
'claude-v1',