diff --git a/lambda/bard_engine.py b/lambda/bard_engine.py new file mode 100644 index 0000000..614969b --- /dev/null +++ b/lambda/bard_engine.py @@ -0,0 +1,68 @@ +import json +import logging +import re +import uuid + +import utils +from Bard import Chatbot +from conversation_history import ConversationHistory +from engine_interface import EngineInterface + +logging.basicConfig() +logging.getLogger().setLevel("INFO") + +class BardEngine(EngineInterface): + + def __init__(self, chatbot: Chatbot) -> None: + self.conversation_id = None + self.parent_id = str(uuid.uuid4()) + self.chatbot = chatbot + self.history = ConversationHistory() + + def reset_chat(self) -> None: + self.conversation_id = None + self.parent_id = str(uuid.uuid4()) + self.chatbot.reset() + + async def ask_async(self, text: str, userConfig: dict) -> str: + # style = userConfig.get("style", "balanced") + response = self.chatbot.ask(message=text) + # item = response["item"] + # self.conversation_id = item["conversationId"] + # try: + # await self.history.write_async( + # conversation_id=self.conversation_id, + # request_id=item["requestId"], + # user_id=userConfig["user_id"], + # conversation=item) + # except Exception as e: + # logging.error(f"conversation_id: {self.conversation_id}, error: {e}") + # finally: + logging.info(json.dumps(response, default=vars)) + return response["message"] + # if "plaintext" in userConfig is True: + # return utils.read_plain_text(response) + # return BardEngine.read_markdown(response) + + async def close(self): + await self.chatbot.close() + + @property + def engine_type(self): + return "bard" + + @classmethod + def create(cls) -> EngineInterface: + token = utils.read_ssm_param(param_name="BARD_TOKEN") + chatbot = Chatbot(session_id=token) + return BardEngine(chatbot) + + @classmethod + def read_plain_text(cls, response: dict) -> str: + return re.sub(pattern=cls.remove_links_pattern, repl="", + string=response["item"]["messages"][1]["text"]) + + @classmethod + def read_markdown(cls, response: dict) -> str: + message = response["item"]["messages"][1]["adaptiveCards"][0]["body"][0]["text"] + return utils.replace_references(text=message) diff --git a/lambda/bing_gpt.py b/lambda/bing_gpt.py index 3e7f336..80ff7af 100644 --- a/lambda/bing_gpt.py +++ b/lambda/bing_gpt.py @@ -4,36 +4,55 @@ import uuid import boto3 -import markdown +import utils +from conversation_history import ConversationHistory from EdgeGPT import Chatbot +from engine_interface import EngineInterface logging.basicConfig() logging.getLogger().setLevel("INFO") -class BingGpt: - ref_link_pattern = re.compile(r"\[(.*?)\]\:\s?(.*?)\s\"(.*?)\"\n?") - esc_pattern = re.compile(f"(? None: - _ssm_client = boto3.client(service_name="ssm") - s3_path = _ssm_client.get_parameter(Name="COOKIES_FILE")["Parameter"]["Value"] +class BingGpt(EngineInterface): + + def __init__(self, chatbot: Chatbot) -> None: self.conversation_id = None self.parent_id = str(uuid.uuid4()) - cookies = self.read_cookies(s3_path) - self.chatbot = Chatbot(cookies=cookies) + self.chatbot = chatbot + self.history = ConversationHistory() def reset_chat(self) -> None: self.conversation_id = None self.parent_id = str(uuid.uuid4()) + self.chatbot.reset() - async def ask(self, text: str, userConfig: dict) -> str: - response = await self.chatbot.ask(prompt=text) - logging.info(json.dumps(response, default=vars)) - return response - + async def ask_async(self, text: str, userConfig: dict) -> str: + style = userConfig.get("style", "balanced") + response = await self.chatbot.ask(prompt=text, + conversation_style=style) + item = response["item"] + self.conversation_id = item["conversationId"] + try: + await self.history.write_async( + conversation_id=self.conversation_id, + request_id=item["requestId"], + user_id=userConfig["user_id"], + conversation=item) + except Exception as e: + logging.error(f"conversation_id: {self.conversation_id}, error: {e}") + finally: + logging.info(json.dumps(response, default=vars)) + + if "plaintext" in userConfig is True: + return utils.read_plain_text(response) + return BingGpt.read_markdown(response) + async def close(self): await self.chatbot.close() + @property + def engine_type(self): + return "bing" + def read_cookies(self, s3_path) -> dict: s3 = boto3.client("s3") bucket_name, file_name = s3_path.replace("s3://", "").split("/", 1) @@ -42,37 +61,20 @@ def read_cookies(self, s3_path) -> dict: return json.loads(file_content) @classmethod - def read_plain_text(cls, response: dict) -> str: - return BingGpt.remove_links(text=response["item"]["messages"][1]["text"]) + def create(cls) -> EngineInterface: + s3_path = utils.read_ssm_param(param_name="COOKIES_FILE") + bucket_name, file_name = s3_path.replace("s3://", "").split("/", 1) + chatbot = Chatbot(cookies=utils.read_json_from_s3(bucket_name, file_name)) + bing = BingGpt(chatbot) + return bing @classmethod - def read_markdown(cls, response: dict) -> str: - message = response["item"]["messages"][1]["adaptiveCards"][0]["body"][0]["text"] - return BingGpt.replace_references(text=message) + def read_plain_text(cls, response: dict) -> str: + return re.sub(pattern=cls.remove_links_pattern, repl="", + string=response["item"]["messages"][1]["text"]) @classmethod - def read_html(cls, response: dict) -> str: + def read_markdown(cls, response: dict) -> str: message = response["item"]["messages"][1]["adaptiveCards"][0]["body"][0]["text"] - text = BingGpt.replace_references(text=message) - return markdown.markdown(text=text) - - @classmethod - def replace_references(cls, text: str) -> str: - ref_links = re.findall(pattern=cls.ref_link_pattern, string=text) - text = re.sub(pattern=cls.ref_link_pattern, repl="", string=text) - text = BingGpt.escape_markdown_v2(text=text) - for link in ref_links: - link_label = link[0] - link_ref = link[1] - inline_link = f" [\[{link_label}\]]({link_ref})" - text = re.sub(pattern=rf"\[\^{link_label}\^\]\[\d+\]", - repl=inline_link, string=text) - return text + return utils.replace_references(text=message) - @classmethod - def remove_links(cls, text: str) -> str: - return re.sub(pattern=r"\[\^\d+\^\]\s?", repl="", string=text) - - @classmethod - def escape_markdown_v2(cls, text: str) -> str: - return re.sub(pattern=cls.esc_pattern, repl=r"\\\1", string=text) diff --git a/lambda/chatbot.py b/lambda/chatbot.py index 6cc5bdb..4d74021 100644 --- a/lambda/chatbot.py +++ b/lambda/chatbot.py @@ -2,9 +2,11 @@ import json import logging -import boto3 +import utils +from bard_engine import BardEngine from bing_gpt import BingGpt -from telegram import Update, constants +from engine_interface import EngineInterface +from telegram import Update from telegram.ext import ( Application, CommandHandler, @@ -13,7 +15,6 @@ filters, ) from user_config import UserConfig -from utils import generate_transcription, send_typing_action example_tg = ''' *bold \*text* @@ -36,36 +37,47 @@ logging.basicConfig() logging.getLogger().setLevel("INFO") -bing = BingGpt() + user_config = UserConfig() +engines = {} -telegram_token = boto3.client(service_name="ssm").get_parameter(Name="TELEGRAM_TOKEN")[ - "Parameter" -]["Value"] +telegram_token = utils.read_ssm_param(param_name="TELEGRAM_TOKEN") app = Application.builder().token(token=telegram_token).build() bot = app.bot logging.info("application startup") # Telegram commands -async def reset(update, context: ContextTypes.DEFAULT_TYPE) -> None: - bing.reset_chat() - await context.bot.send_message( - chat_id=update.message.chat_id, text="Conversation has been reset" +async def reset(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + user_id = update.effective_user.id + config = user_config.read(user_id) + engine = get_engine(config) + engine.reset_chat() + await update.message.reply_text( + text="Conversation has been reset" ) -async def set_engine(update, context: ContextTypes.DEFAULT_TYPE) -> None: - user_id = update.message.from_user.id - logging.info(f"user_id: {user_id}") - config = user_config.create_config(user_id) - logging.info(f"config: {config}") - engine = update.message.text.strip("/").lower() - logging.info(f"engine: {engine}") - config["engine"] = engine - logging.info(f"config: {config}") +async def set_style(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + user_id = update.effective_user.id + config = user_config.read(user_id) + style = update.message.text.strip("/").lower() + config["style"] = style + logging.info(f"user: {user_id} set engine style to: '{style}'") user_config.write(user_id, config) await update.message.reply_text( - text=f"Bot engine has been set to {engine}" + text=f"Bot engine style has been set to '{style}'" + ) + +async def set_engine(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + user_id = update.effective_user.id + config = user_config.read(user_id) + engine_type = update.message.text.strip("/").lower() + logging.info(f"engine: {engine_type}") + config["engine"] = engine_type + logging.info(f"user: {user_id} set engine to: {engine_type}") + user_config.write(user_id, config) + await update.message.reply_text( + text=f"Bot engine has been set to {engine_type}" ) async def set_plaintext(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: @@ -85,56 +97,68 @@ async def send_example(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No # Telegram handlers -# @send_typing_action -async def process_voice_message(update, context: ContextTypes.DEFAULT_TYPE): +async def process_voice_message(update: Update, context: ContextTypes.DEFAULT_TYPE): voice_message = update.message.voice file_id = voice_message.file_id file = await bot.get_file(file_id) - transcript_msg = generate_transcription(file) + transcript_msg = utils.generate_transcription(file) logging.info(transcript_msg) - user_id = int(update.message.from_user.id) - config = user_config.read(user_id) - message = await bing.ask(transcript_msg, config) - chat_id = update.message.chat_id - await context.bot.send_message( - chat_id=chat_id, - text=message, - parse_mode=constants.ParseMode.MARKDOWN_V2, - ) - - -async def process_message(update, context: ContextTypes.DEFAULT_TYPE): - if update.message.text is None: - return - if bot.name not in update.message.text and "group" in update.message.chat.type: - return try: - await processing_internal(update, context) + user_id = int(update.effective_message.from_user.id) + config = user_config.read(user_id) + engine = get_engine(config) + await process_internal(engine, config, update, context) except Exception as e: logging.error(e) -@send_typing_action -async def processing_internal(update, context: ContextTypes.DEFAULT_TYPE): - # chat_id = update.message.chat_id - chat_text = update.message.text.replace(bot.name, "") +async def process_message(update: Update, context: ContextTypes.DEFAULT_TYPE): + if update.effective_message is None: + logging.info(update) + return + if bot.name not in update.effective_message.text: + # and update.effective_message.chat.type == constants.ChatType.GROUP: + return try: - user_id = int(update.message.from_user.id) + user_id = int(update.effective_user.id) config = user_config.read(user_id) - response = await bing.ask(chat_text, config) - if "plaintext" in config is True: - await update.message.reply_text( - text=BingGpt.read_plain_text(response=response), - disable_notification=True) - else: - await update.message.reply_markdown_v2( - text=BingGpt.read_markdown(response=response), - disable_notification=True, - disable_web_page_preview=True) + engine = get_engine(config) + await process_internal(update, context, engine, config) except Exception as e: logging.error(e) +@utils.send_typing_action +async def process_internal(update: Update, context: ContextTypes.DEFAULT_TYPE, + engine: EngineInterface, config: UserConfig): + chat_text = update.effective_message.text.replace(bot.name, "") + response = await engine.ask_async(chat_text, config) + + if "plaintext" in config is True: + await update.effective_message.reply_text( + text=response, + disable_notification=True, + disable_web_page_preview=True) + else: + await update.effective_message.reply_markdown_v2( + text=response, + disable_notification=True, + disable_web_page_preview=True) + + +def get_engine(config: UserConfig) -> EngineInterface: + engine_type = config["engine"] + if engine_type in engines: + return engines[engine_type] + engine: EngineInterface = None + if engine_type == "bing": + engine = BingGpt.create() + elif engine_type == "bard": + engine = BardEngine.create() + engines[engine_type] = engine + return engine + + # Lambda message handler def message_handler(event, context): @@ -147,14 +171,14 @@ async def main(event): filters=filters.COMMAND)) app.add_handler(CommandHandler(["plaintext", "markdown"], set_plaintext, filters=filters.COMMAND)) + app.add_handler(CommandHandler(["creative", "balanced", "precise"], set_style, + filters=filters.COMMAND)) app.add_handler(CommandHandler("example", send_example, filters=filters.COMMAND)) - app.add_handler(MessageHandler(filters.TEXT, process_message)) - # app.add_handler(MessageHandler(filters.CHAT, process_message)) + app.add_handler(MessageHandler(filters.ALL, process_message)) app.add_handler(MessageHandler(filters.VOICE, process_voice_message)) try: await app.initialize() await app.process_update(Update.de_json(json.loads(event["body"]), bot)) - return {"statusCode": 200, "body": "Success"} except Exception as ex: diff --git a/lambda/chatgpt.py b/lambda/chatgpt.py index 3c635db..9f65393 100644 --- a/lambda/chatgpt.py +++ b/lambda/chatgpt.py @@ -52,7 +52,6 @@ def ask(self, text) -> str: message = "..." logging.error(f"Request:{data}, Response:{response}") else: - # logging.info(response.text) result = response.json() self.parent_id = result["id"] self.conversation_id = result["conversation_id"] diff --git a/lambda/chatsonic.py b/lambda/chatsonic.py index 7f0f8e6..7a59c67 100644 --- a/lambda/chatsonic.py +++ b/lambda/chatsonic.py @@ -1,13 +1,17 @@ +import json import logging +import re import uuid import boto3 -from Conversation import Conversation +import utils +from conversation import Conversation +from engine_interface import EngineInterface logging.getLogger().setLevel("INFO") -class ChatSonic: +class ChatSonic(EngineInterface): def __init__(self) -> None: _ssm_client = boto3.client(service_name="ssm") token = _ssm_client.get_parameter(Name="CHATSONIC_TOKEN")["Parameter"]["Value"] @@ -19,9 +23,10 @@ def reset_chat(self) -> None: self.conversation_id = None self.parent_id = str(uuid.uuid4()) - def ask(self, text) -> str: + async def ask_async(self, text: str, userConfig: dict) -> str: try: response = self.chatsonic.send_message(message=text) + logging.info(json.dumps(response, default=vars)) except Exception as e: logging.error(e) @@ -31,4 +36,24 @@ def ask(self, text) -> str: message = next( (r["message"] for r in response if r["is_sent"] is False), "" ) + if "plaintext" in userConfig is True: + return ChatSonic.read_plain_text(message) + return ChatSonic.read_markdown(message) return message + + def close(self): + pass + + @property + def engine_type(self): + return "chatsonic" + + @classmethod + def read_plain_text(cls, response: dict) -> str: + return re.sub(pattern=cls.remove_links_pattern, repl="", + string=response["item"]["messages"][1]["text"]) + + @classmethod + def read_markdown(cls, response: dict) -> str: + message = response["item"]["messages"][1]["adaptiveCards"][0]["body"][0]["text"] + return utils.escape_markdown_v2(text=message) diff --git a/lambda/conversation_history.py b/lambda/conversation_history.py new file mode 100644 index 0000000..a6a55fe --- /dev/null +++ b/lambda/conversation_history.py @@ -0,0 +1,40 @@ +import asyncio +import json +import logging +import time +from typing import Optional + +import boto3 + +logging.basicConfig() +logging.getLogger().setLevel("INFO") + + +class ConversationHistory: + def __init__(self) -> None: + dynamodb = boto3.resource("dynamodb") + self.table = dynamodb.Table("conversations") + + def read(self, conversation_id: str, request_id: str) -> Optional[dict]: + try: + resp = self.table.get_item(Key = { + "conversation_id" : conversation_id, + "request_id": request_id + }) + if "Item" in resp: + return json.loads(resp["Item"]["conversation"]) + return None + except Exception as e: + logging.error(e) + + async def write_async(self, conversation_id: str, request_id: str, user_id: int, + conversation): + asyncio.run(self.table.put_item( + Item={ + "conversation_id": conversation_id, + "request_id": request_id, + "user_id": user_id, + "timestamp": int(time.time()), + "conversation": json.dumps(conversation) + }) + ) diff --git a/lambda/engine_interface.py b/lambda/engine_interface.py new file mode 100644 index 0000000..c32b580 --- /dev/null +++ b/lambda/engine_interface.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + + +class EngineInterface(ABC): + + @abstractmethod + async def ask_async(self, text, userConfig: dict) -> str: + pass + + @abstractmethod + def reset_chat(self): + pass + + @abstractmethod + def close(self): + pass + + @property + @abstractmethod + def engine_type(self) -> str: + pass diff --git a/lambda/requirements.txt b/lambda/requirements.txt index c32c9d7..70f3842 100644 --- a/lambda/requirements.txt +++ b/lambda/requirements.txt @@ -1,7 +1,7 @@ python-telegram-bot~=20.1 git+https://github.com/brainboost/EdgeGPT.git +GoogleBard websockets requests boto3 wget -markdown \ No newline at end of file diff --git a/lambda/user_config.py b/lambda/user_config.py index 268e799..543a042 100644 --- a/lambda/user_config.py +++ b/lambda/user_config.py @@ -13,7 +13,7 @@ def __init__(self) -> None: dynamodb = boto3.resource("dynamodb") self.table = dynamodb.Table("user-configurations") - def read(self, user_id: int): + def read(self, user_id: int) -> dict: try: resp = self.table.get_item(Key = { "user_id" : user_id }) if "Item" in resp: @@ -23,13 +23,15 @@ def read(self, user_id: int): except Exception as e: logging.error(e) - def write(self, user_id: int, config): + def write(self, user_id: int, config): + config["user_id"] = user_id self.table.put_item( Item={"user_id": user_id, "config": json.dumps(config)}) - def create_config(self, user_id: int): + def create_config(self, user_id: int) -> dict: return { "plaintext": False, + "user_id": user_id, "engine": "bing", "updated": int(time.time()), } diff --git a/lambda/utils.py b/lambda/utils.py index a2255e2..37eaea0 100644 --- a/lambda/utils.py +++ b/lambda/utils.py @@ -1,42 +1,35 @@ import json import logging import os +import re import uuid from functools import wraps import boto3 import wget -from telegram import constants +from telegram import Update, constants logging.basicConfig() logging.getLogger().setLevel("INFO") +ref_link_pattern = re.compile(r"\[(.*?)\]\:\s?(.*?)\s\"(.*?)\"\n?") +esc_pattern = re.compile(f"(? str: + ssm_client = boto3.client(service_name="ssm") + return ssm_client.get_parameter(Name=param_name)[ + "Parameter" + ]["Value"] + +def read_json_from_s3(bucket_name: str, file_name: str) -> dict: + s3 = boto3.client("s3") + response = s3.get_object(Bucket=bucket_name, Key=file_name) + file_content = response["Body"].read().decode("utf-8") + return json.loads(file_content) + +def replace_references(text: str) -> str: + ref_links = re.findall(pattern=ref_link_pattern, string=text) + text = re.sub(pattern=ref_link_pattern, repl="", string=text) + text = escape_markdown_v2(text=text) + for link in ref_links: + link_label = link[0] + link_ref = link[1] + inline_link = f" [\[{link_label}\]]({link_ref})" + text = re.sub(pattern=rf"\[\^{link_label}\^\]\[\d+\]", + repl=inline_link, string=text) + return text + +def escape_markdown_v2(text: str) -> str: + return re.sub(pattern=esc_pattern, repl=r"\\\1", string=text) diff --git a/stacks/chatgpt_bot_stack.py b/stacks/chatgpt_bot_stack.py index c70430d..496b426 100644 --- a/stacks/chatgpt_bot_stack.py +++ b/stacks/chatgpt_bot_stack.py @@ -15,7 +15,7 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None: lambda_role = iam.Role( self, - "BotRole", + "ChatBotRole", assumed_by=iam.ServicePrincipal("lambda.amazonaws.com"), managed_policies=[ iam.ManagedPolicy.from_aws_managed_policy_name( @@ -31,8 +31,8 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None: self.bucket = s3.Bucket( self, - f"{construct_id}-Bucket", - bucket_name=f"{construct_id}-s3-bucket".lower(), + f"{construct_id}-s3-Bucket", + bucket_name=f"{construct_id}-s3-bucket-temp".lower(), removal_policy=_removalpolicy.DESTROY, block_public_access=s3.BlockPublicAccess.BLOCK_ALL, versioned=True, diff --git a/stacks/database_stack.py b/stacks/database_stack.py index c1936e7..d6aa177 100644 --- a/stacks/database_stack.py +++ b/stacks/database_stack.py @@ -19,7 +19,7 @@ def __init__(self, scope: Construct, id: str, **kwargs) -> None: billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST ) - dynamodb.Table( + conversationTable = dynamodb.Table( self, "conversations-table", table_name="conversations", partition_key=dynamodb.Attribute( @@ -27,9 +27,27 @@ def __init__(self, scope: Construct, id: str, **kwargs) -> None: type=dynamodb.AttributeType.STRING ), sort_key=dynamodb.Attribute( - name="user_id", - type=dynamodb.AttributeType.NUMBER + name="request_id", + type=dynamodb.AttributeType.STRING ), removal_policy=RemovalPolicy.RETAIN, billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST + ) + + conversationTable.add_global_secondary_index( + index_name="user-id-index", + partition_key=dynamodb.Attribute( + name="user_id", + type=dynamodb.AttributeType.NUMBER + ), + projection_type=dynamodb.ProjectionType.ALL + ) + + conversationTable.add_local_secondary_index( + index_name="updated-index", + sort_key=dynamodb.Attribute( + name="timestamp", + type=dynamodb.AttributeType.NUMBER + ), + projection_type=dynamodb.ProjectionType.ALL ) \ No newline at end of file