-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from brainboost/development
Adding pub/sub and images generation
- Loading branch information
Showing
27 changed files
with
939 additions
and
548 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,41 @@ | ||
#!/usr/bin/env python3 | ||
import os | ||
|
||
import aws_cdk as cdk | ||
|
||
from stacks.chatgpt_bot_stack import ChatgptBotStack | ||
from stacks.chatbot_stack import ChatBotStack | ||
from stacks.database_stack import DatabaseStack | ||
from stacks.engines_stack import EnginesStack | ||
|
||
# Root of the CDK construct tree; every stack below is attached to it.
app = cdk.App()

# NOTE(review): the legacy `ChatgptBotStack(app, "ChatgptBotStack")` and the
# positional `DatabaseStack(app, "DatabaseStack")` instantiations were dropped.
# Keeping the latter next to the keyword-style DatabaseStack below would
# register the construct id "DatabaseStack" twice in the same scope.

# Deployment target, taken from the variables CDK_DEFAULT_ACCOUNT and
# CDK_DEFAULT_REGION. A missing variable raises KeyError here.
env = cdk.Environment(
    account=os.environ["CDK_DEFAULT_ACCOUNT"], region=os.environ["CDK_DEFAULT_REGION"]
)
# Deployment stage passed to the bot stack; defaults to "dev" when STAGE is unset.
stage = os.environ.get("STAGE", "dev")

engStack = EnginesStack(
    scope=app,
    construct_id="EnginesStack",
    description="A stack that creates lambda functions working with AI engines APIs",
    env=env,
)

databaseStack = DatabaseStack(
    scope=app,
    construct_id="DatabaseStack",
    description="A stack containing database",
    env=env,
)

botStack = ChatBotStack(
    scope=app,
    construct_id="ChatBotStack",
    description="A stack containing telegram bot lambda",
    env=env,
    stage=stage,
)
# The bot stack is declared to depend on both the engines and database stacks.
botStack.add_dependency(engStack)
botStack.add_dependency(databaseStack)

app.synth()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import asyncio | ||
import json | ||
import logging | ||
import re | ||
import uuid | ||
|
||
import boto3 | ||
import common_utils as utils | ||
from conversation_history import ConversationHistory | ||
from EdgeGPT import Chatbot | ||
from engine_interface import EngineInterface | ||
|
||
logging.basicConfig() | ||
logging.getLogger().setLevel("INFO") | ||
|
||
|
||
class BingGpt(EngineInterface):
    """Engine that forwards prompts to a ``Chatbot`` and formats its answers.

    Each answered request is also handed to a ``ConversationHistory`` instance
    for storage. Answers are returned either as plain text (citation markers
    stripped) or as markdown with reference definitions turned into inline
    links and special characters backslash-escaped.
    """

    def __init__(self, chatbot: Chatbot) -> None:
        """Store the chatbot and pre-compile the patterns used for formatting."""
        # Citation markers such as "[^1^] " found in the plain-text answer.
        self.remove_links_pattern = re.compile(r"\[\^\d+\^\]\s?")
        # Reference definitions of the form: [label]: url "title"
        self.ref_link_pattern = re.compile(r"\[(.*?)\]\:\s?(.*?)\s\"(.*?)\"\n?")
        # One of the listed characters, provided it is neither preceded nor
        # followed by "|". Built by concatenating raw strings so the pattern
        # carries no invalid escape sequences.
        # NOTE(review): presumably the character set Telegram MarkdownV2 needs
        # escaped - confirm against the bot's parse mode.
        escaped_chars = re.escape(r".-+#|{}!=()<>")
        self.esc_pattern = re.compile(r"(?<!\|)([" + escaped_chars + r"])(?!\|)")
        self.conversation_id = None
        # Not read anywhere in this class; kept because callers may rely on it.
        self.parent_id = str(uuid.uuid4())
        self.chatbot = chatbot
        self.history = ConversationHistory()

    @staticmethod
    def _split_s3_path(s3_path: str) -> tuple:
        """Split ``s3://bucket/key`` into ``(bucket, key)``.

        Raises:
            ValueError: if the path contains no "/" after the bucket name.
        """
        bucket_name, file_name = s3_path.replace("s3://", "").split("/", 1)
        return bucket_name, file_name

    def reset_chat(self) -> None:
        """Forget the current conversation and reset the underlying chatbot."""
        self.conversation_id = None
        self.parent_id = str(uuid.uuid4())
        self.chatbot.reset()

    def ask(self, text: str, userConfig: dict) -> str:
        """Send ``text`` to the chatbot and return the formatted answer.

        Args:
            text: The user prompt. Any text containing "/ping" is answered
                with "pong" without contacting the chatbot.
            userConfig: Per-user settings. Reads "style" (default "creative"),
                "user_id" (stored with the history record) and "plaintext"
                (truthy selects plain-text output; a missing key selects
                markdown).

        Returns:
            The answer as plain text or as markdown, depending on
            ``userConfig["plaintext"]``.
        """
        if "/ping" in text:
            return "pong"
        style = userConfig.get("style", "creative")
        response = asyncio.run(self.chatbot.ask(prompt=text, conversation_style=style))
        item = response["item"]
        self.conversation_id = item["conversationId"]
        try:
            self.history.write(
                conversation_id=self.conversation_id,
                request_id=item["requestId"],
                user_id=userConfig["user_id"],
                conversation=item,
            )
        except Exception as e:
            # Best effort: a failed history write must not block the answer,
            # but the traceback is kept so the failure can be diagnosed.
            logging.exception(
                "conversation_id: %s, error: %s", self.conversation_id, e
            )
        logging.info(json.dumps(response, default=vars))

        if userConfig.get("plaintext"):
            return self.read_plain_text(response)
        return self.read_markdown(response)

    def close(self):
        """Close the underlying chatbot."""
        self.chatbot.close()

    @property
    def engine_type(self):
        """Identifier of this engine, attached to outgoing result messages."""
        return "bing"

    def read_cookies(self, s3_path) -> dict:
        """Load and parse the JSON object stored at ``s3_path`` (``s3://bucket/key``)."""
        bucket_name, file_name = self._split_s3_path(s3_path)
        # Delegates to the shared helper instead of duplicating the S3 download.
        return utils.read_json_from_s3(bucket_name, file_name)

    def read_plain_text(self, response: dict) -> str:
        """Return the answer text with "[^n^]" citation markers removed.

        Reads ``response["item"]["messages"][1]["text"]``.
        NOTE(review): assumes the second message is the bot's reply - confirm.
        """
        return self.remove_links_pattern.sub(
            "", response["item"]["messages"][1]["text"]
        )

    def read_markdown(self, response: dict) -> str:
        """Return the first adaptive-card body text of the second message as markdown."""
        message = response["item"]["messages"][1]["adaptiveCards"][0]["body"][0]["text"]
        logging.info(message)
        return self.replace_references(text=message)

    def replace_references(self, text: str) -> str:
        """Inline reference links and backslash-escape special characters.

        Reference definitions (``[label]: url "title"``) are removed from the
        text, special characters are escaped, and every ``[^label^][n]``
        marker is replaced by `` [\\[label\\]](url)``.
        """
        ref_links = self.ref_link_pattern.findall(text)
        text = self.ref_link_pattern.sub("", text)
        text = self.esc_pattern.sub(r"\\\1", text)
        for link_label, link_ref, _title in ref_links:
            inline_link = rf" [\[{link_label}\]]({link_ref})"
            # The label is escaped so regex metacharacters in it cannot break
            # the pattern. The replacement goes through a function so that
            # backslashes in the URL are inserted literally rather than being
            # parsed as a replacement template.
            marker = re.compile(rf"\[\^{re.escape(link_label)}\^\]\[\d+\]")
            text = marker.sub(lambda _match, repl=inline_link: repl, text)
        return text

    @classmethod
    def create(cls) -> EngineInterface:
        """Build an engine from the cookies file named by the COOKIES_FILE SSM parameter."""
        s3_path = utils.read_ssm_param(param_name="COOKIES_FILE")
        bucket_name, file_name = cls._split_s3_path(s3_path)
        chatbot = asyncio.run(
            Chatbot.create(cookies=utils.read_json_from_s3(bucket_name, file_name))
        )
        # Use cls so subclasses get an instance of their own type.
        return cls(chatbot)
|
||
|
||
# Module-level setup, executed once when the module is imported:
# the engine instance, the URL of the results queue, and an SQS client.
bing = BingGpt.create()
results_queue = utils.read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
_boto_session = boto3.session.Session()
sqs = _boto_session.client("sqs")
|
||
# AWS SQS handler | ||
|
||
|
||
def sqs_handler(event, context):
    """Answer every SQS record in ``event`` and publish the result.

    Each record body is JSON with at least "text" and "config". The answer is
    encoded with ``utils.encode_message`` and stored under "response", the
    engine identifier under "engine", and the enriched payload is sent to the
    results queue. ``context`` is unused.
    """
    for record in event["Records"]:
        payload = json.loads(record["body"])
        logging.info(payload)
        answer = bing.ask(payload["text"], payload["config"])
        logging.info(answer)
        payload.update(
            response=utils.encode_message(answer),
            engine=bing.engine_type,
        )
        logging.info(payload)
        outgoing_body = json.dumps(payload)
        sqs.send_message(QueueUrl=results_queue, MessageBody=outgoing_body)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import base64 | ||
import json | ||
import logging | ||
import zlib | ||
|
||
import boto3 | ||
|
||
logging.basicConfig() | ||
logging.getLogger().setLevel("INFO") | ||
|
||
|
||
def read_ssm_param(param_name: str) -> str:
    """Return the value of the SSM parameter named ``param_name``.

    A new SSM client is created on every call.
    """
    response = boto3.client(service_name="ssm").get_parameter(Name=param_name)
    parameter = response["Parameter"]
    return parameter["Value"]
|
||
|
||
def read_json_from_s3(bucket_name: str, file_name: str) -> dict:
    """Download ``file_name`` from ``bucket_name`` and parse it as UTF-8 JSON.

    The parsed document is returned as-is.
    NOTE(review): annotated as ``dict``, but the result is whatever the JSON
    top-level value is - verify against callers.
    """
    s3_object = boto3.client("s3").get_object(Bucket=bucket_name, Key=file_name)
    raw_bytes = s3_object["Body"].read()
    return json.loads(raw_bytes.decode("utf-8"))
|
||
|
||
def encode_message(text: str) -> str:
    """Compress ``text`` with zlib and return the result as base64 ASCII.

    The text is UTF-8 encoded before compression.
    """
    utf8_bytes = text.encode("utf-8")
    return str(base64.b64encode(zlib.compress(utf8_bytes)), "ascii")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import json | ||
import logging | ||
import time | ||
from typing import Optional | ||
|
||
import boto3 | ||
|
||
logging.basicConfig() | ||
logging.getLogger().setLevel("INFO") | ||
|
||
|
||
class ConversationHistory:
    """Stores and retrieves conversation records in a DynamoDB table.

    Records are keyed by ``conversation_id`` and ``request_id``; the
    conversation itself is stored as a JSON string.
    """

    def __init__(self, table_name: str = "conversations") -> None:
        """Bind to the DynamoDB table ``table_name`` (default "conversations")."""
        dynamodb = boto3.resource("dynamodb")
        self.table = dynamodb.Table(table_name)

    def read(self, conversation_id: str, request_id: str) -> Optional[dict]:
        """Return the stored conversation, or ``None`` when it cannot be loaded.

        ``None`` is returned both when no item exists for the key and when the
        lookup or JSON decoding fails; failures are logged with a traceback.
        """
        try:
            resp = self.table.get_item(
                Key={"conversation_id": conversation_id, "request_id": request_id}
            )
            if "Item" not in resp:
                return None
            return json.loads(resp["Item"]["conversation"])
        except Exception:
            # Deliberate best effort: callers get None instead of an exception,
            # but the full traceback and the key are logged for diagnosis.
            logging.exception(
                "failed to read conversation_id: %s, request_id: %s",
                conversation_id,
                request_id,
            )
            return None

    def write(self, conversation_id: str, request_id: str, user_id: int, conversation):
        """Store ``conversation`` as JSON together with the current Unix timestamp.

        Errors from ``put_item`` or from JSON encoding propagate to the caller.
        """
        self.table.put_item(
            Item={
                "conversation_id": conversation_id,
                "request_id": request_id,
                "user_id": user_id,
                "timestamp": int(time.time()),
                "conversation": json.dumps(conversation),
            }
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import json | ||
import logging | ||
|
||
import boto3 | ||
import common_utils as utils | ||
from BingImageCreator import ImageGen | ||
|
||
logging.basicConfig() | ||
logging.getLogger().setLevel("INFO") | ||
|
||
|
||
def create() -> ImageGen:
    """Build an ``ImageGen`` from the "_U" cookie in the stored cookies file.

    The S3 location of the cookies file is read from the COOKIES_FILE SSM
    parameter. The file is expected to hold a list of cookie objects with
    "name" and "value" keys.

    Raises:
        IndexError: if the cookies file has no cookie named "_U". The
            exception type matches what the previous ``[...][0]`` lookup
            raised, now with a descriptive message.
    """
    s3_path = utils.read_ssm_param(param_name="COOKIES_FILE")
    bucket_name, file_name = s3_path.replace("s3://", "").split("/", 1)
    auth_cookies = utils.read_json_from_s3(bucket_name, file_name)
    for cookie in auth_cookies:
        if cookie.get("name") == "_U":
            return ImageGen(cookie.get("value"))
    raise IndexError(f"no '_U' cookie found in cookies file {s3_path}")
|
||
|
||
# Module-level setup, executed once when the module is imported:
# the image generator, the URL of the results queue, and an SQS client.
imageGen = create()
results_queue = utils.read_ssm_param(param_name="RESULTS_SQS_QUEUE_URL")
_boto_session = boto3.session.Session()
sqs = _boto_session.client("sqs")
|
||
|
||
def sqs_handler(event, context):
    """Generate images for every SQS record in ``event`` and publish the links.

    Each record body is JSON with a "text" prompt. The resulting image URLs
    are joined with newlines, encoded with ``utils.encode_message`` and stored
    under "response"; "engine" is set to "Dall-E". The enriched payload is
    sent to the results queue. ``context`` is unused.
    """
    # Returned instead of calling the generator when the prompt is None or
    # blank (for testing purposes).
    placeholder_urls = [
        "https://picsum.photos/200#1",
        "https://picsum.photos/200#2",
        "https://picsum.photos/200#3",
        "https://picsum.photos/200#4",
    ]
    for record in event["Records"]:
        payload = json.loads(record["body"])
        logging.info(payload)
        prompt = payload["text"]
        # Named image_urls rather than `list` so the builtin is not shadowed.
        if prompt is not None and prompt.strip():
            image_urls = imageGen.get_images(prompt)
        else:
            image_urls = list(placeholder_urls)
        logging.info(image_urls)
        message = "\n".join(image_urls)
        payload["response"] = utils.encode_message(message)
        payload["engine"] = "Dall-E"
        logging.info(payload)
        sqs.send_message(QueueUrl=results_queue, MessageBody=json.dumps(payload))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
git+https://github.com/brainboost/EdgeGPT.git | ||
GoogleBard | ||
BingImageCreator | ||
websockets | ||
requests | ||
boto3 | ||
wget |
Oops, something went wrong.