Skip to content

Commit 9e5ef89

Browse files
author
starpig1129
committed
Rewrite the program structure (remove bot separately)And update the log to independent records for each server
1 parent d4507e3 commit 9e5ef89

File tree

6 files changed

+280
-247
lines changed

6 files changed

+280
-247
lines changed

bot.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
import discord
2+
import sys
3+
import os
4+
import re
5+
import traceback
6+
import aiohttp
7+
import update
8+
import function as func
9+
import json
10+
import logging
11+
from zhconv import convert
12+
from discord.ext import commands
13+
from web import IPCServer
14+
from motor.motor_asyncio import AsyncIOMotorClient
15+
from datetime import datetime
16+
from voicelink import VoicelinkException
17+
from gpt.choose_act import choose_act
18+
from gpt.sendmessage import gpt_message, load_and_index_dialogue_history, save_vector_store, vector_store
19+
from logs import TimedRotatingFileHandler
20+
class Translator(discord.app_commands.Translator):
21+
async def load(self):
22+
print("Loaded Translator")
23+
24+
async def unload(self):
25+
print("Unload Translator")
26+
27+
async def translate(self, string: discord.app_commands.locale_str, locale: discord.Locale, context: discord.app_commands.TranslationContext):
28+
if str(locale) in func.LOCAL_LANGS:
29+
return func.LOCAL_LANGS[str(locale)].get(string.message, None)
30+
return None
31+
# 配置 logging
32+
def setup_logger(server_name):
33+
logger = logging.getLogger(server_name)
34+
logger.setLevel(logging.INFO)
35+
handler = TimedRotatingFileHandler(server_name)
36+
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
37+
handler.setFormatter(formatter)
38+
logger.addHandler(handler)
39+
return logger
40+
class PigPig(commands.Bot):
41+
def __init__(self, *args, **kwargs):
42+
super().__init__(*args, **kwargs)
43+
self.dialogue_history_file = './data/dialogue_history.json'
44+
self.vector_store_path = './data/vector_store'
45+
self.load_dialogue_history()
46+
load_and_index_dialogue_history(self.dialogue_history_file)
47+
self.ipc = IPCServer(
48+
self,
49+
host=func.settings.ipc_server["host"],
50+
port=func.settings.ipc_server["port"],
51+
sercet_key=func.tokens.sercet_key
52+
)
53+
self.loggers = {}
54+
55+
def setup_logger_for_guild(self, guild_name):
56+
if guild_name not in self.loggers:
57+
self.loggers[guild_name] = setup_logger(guild_name)
58+
59+
def get_logger_for_guild(self, guild_name):
60+
if guild_name in self.loggers:
61+
return self.loggers[guild_name]
62+
else:
63+
self.setup_logger_for_guild(guild_name)
64+
return self.loggers[guild_name]
65+
66+
def setup_logger_for_guild(self, guild_name):
67+
if guild_name not in self.loggers:
68+
self.loggers[guild_name] = setup_logger(guild_name)
69+
70+
def load_dialogue_history(self):
71+
"""從檔案中讀取對話歷史"""
72+
try:
73+
with open(self.dialogue_history_file, 'r', encoding='utf-8') as file:
74+
self.dialogue_history = json.load(file)
75+
except FileNotFoundError:
76+
self.dialogue_history = {}
77+
78+
def save_dialogue_history(self):
79+
"""將對話歷史保存到檔案中"""
80+
with open(self.dialogue_history_file, 'w', encoding='utf-8') as file:
81+
json.dump(self.dialogue_history, file, ensure_ascii=False, indent=4)
82+
save_vector_store(vector_store, self.vector_store_path)
83+
84+
async def on_message(self, message: discord.Message, /) -> None:
85+
if message.author.bot or not message.guild:
86+
return
87+
88+
guild_name = message.guild.name
89+
self.setup_logger_for_guild(guild_name)
90+
logger = self.loggers[guild_name]
91+
92+
logger.info(f'收到訊息: {message.content} (來自:伺服器:{message.guild},頻道:{message.channel.name},{message.author.name})')
93+
await self.process_commands(message)
94+
95+
channel_id = str(message.channel.id)
96+
if channel_id not in self.dialogue_history:
97+
self.dialogue_history[channel_id] = []
98+
99+
try:
100+
match = re.search(r"<@\d+>\s*(.*)", message.content)
101+
prompt = match.group(1)
102+
except AttributeError: # 如果正則表達式沒有匹配到,會拋出 AttributeError
103+
prompt = message.content
104+
105+
self.dialogue_history[channel_id].append({"role": "user", "content": prompt})
106+
# 實現生成回應的邏輯
107+
if self.user.id in message.raw_mentions and not message.mention_everyone:
108+
# 發送初始訊息
109+
message_to_edit = await message.reply("思考中...")
110+
try:
111+
execute_action = await choose_act(self,prompt, message, message_to_edit)
112+
await execute_action(message_to_edit, self.dialogue_history, channel_id, prompt, message)
113+
except Exception as e:
114+
print(e)
115+
self.save_dialogue_history()
116+
117+
async def on_message_edit(self, before: discord.Message, after: discord.Message):
118+
if before.author.bot or not before.guild:
119+
return
120+
121+
logger = self.get_logger_for_guild(before.guild.name)
122+
logger.info(
123+
f"訊息修改: 原訊息({before.content}) 新訊息({after.content}) 頻道:{before.channel.name}, 作者:{before.author}"
124+
)
125+
channel_id = str(after.channel.id)
126+
if channel_id not in self.dialogue_history:
127+
self.dialogue_history[channel_id] = []
128+
129+
try:
130+
match = re.search(r"<@\d+>\s*(.*)", after.content)
131+
prompt = match.group(1)
132+
except AttributeError: # 如果正則表達式沒有匹配到,會拋出 AttributeError
133+
prompt = after.content
134+
135+
self.dialogue_history[channel_id].append({"role": "user", "content": prompt})
136+
137+
# 實現生成回應的邏輯
138+
if self.user.id in after.raw_mentions and not after.mention_everyone:
139+
try:
140+
# Fetch the bot's previous reply
141+
async for msg in after.channel.history(limit=50):
142+
if msg.reference and msg.reference.message_id == before.id and msg.author.id == self.user.id:
143+
await msg.delete() # 删除之前的回复
144+
145+
message_to_edit = await after.reply("思考中...") # 创建新的回复
146+
execute_action = await choose_act(self,prompt, after, message_to_edit)
147+
await execute_action(message_to_edit, self.dialogue_history, channel_id, prompt, after)
148+
except Exception as e:
149+
print(e)
150+
self.save_dialogue_history()
151+
152+
153+
async def connect_db(self) -> None:
154+
if not ((db_name := func.tokens.mongodb_name) and (db_url := func.tokens.mongodb_url)):
155+
raise Exception("MONGODB_NAME and MONGODB_URL can't not be empty in settings.json")
156+
157+
try:
158+
func.MONGO_DB = AsyncIOMotorClient(host=db_url, serverSelectionTimeoutMS=5000)
159+
await func.MONGO_DB.server_info()
160+
print("Successfully connected to MongoDB!")
161+
162+
except Exception as e:
163+
raise Exception("Not able to connect MongoDB! Reason:", e)
164+
165+
func.SETTINGS_DB = func.MONGO_DB[db_name]["Settings"]
166+
func.USERS_DB = func.MONGO_DB[db_name]["Users"]
167+
168+
async def setup_hook(self) -> None:
169+
func.langs_setup()
170+
171+
# Connecting to MongoDB
172+
await self.connect_db()
173+
# Loading all the module in `cogs` folder
174+
for module in os.listdir(func.ROOT_DIR + '/cogs'):
175+
if module.endswith('.py'):
176+
try:
177+
await self.load_extension(f"cogs.{module[:-3]}")
178+
print(f"Loaded {module[:-3]}")
179+
except Exception as e:
180+
print(traceback.format_exc())
181+
182+
if func.settings.ipc_server.get("enable", False):
183+
await self.ipc.start()
184+
185+
if not func.settings.version or func.settings.version != update.__version__:
186+
func.update_json("settings.json", new_data={"version": update.__version__})
187+
188+
await self.tree.set_translator(Translator())
189+
await self.tree.sync()
190+
191+
async def on_ready(self):
192+
print("------------------")
193+
print(f"Logging As {self.user}")
194+
print(f"Bot ID: {self.user.id}")
195+
print("------------------")
196+
print(f"Discord Version: {discord.__version__}")
197+
print(f"Python Version: {sys.version}")
198+
print("------------------")
199+
data = {}
200+
data['guilds'] = []
201+
for guild in self.guilds:
202+
guild_info = {
203+
'guild_name': guild.name,'guild_id': guild.id,
204+
'channels': []
205+
}
206+
for channel in guild.channels:
207+
channel_info =f"channel_name: {channel.name},channel_id: {channel.id},channel_type: {str(channel.type)}"
208+
guild_info['channels'].append(channel_info)
209+
data['guilds'].append(guild_info)
210+
self.setup_logger_for_guild(guild.name) # 設置每個伺服器的 logger
211+
212+
# 將資料寫入 JSON 文件
213+
with open('logs/guilds_and_channels.json', 'w', encoding='utf-8') as f:
214+
json.dump(data, f, ensure_ascii=False, indent=4)
215+
print('update succesfully guilds_and_channels.json')
216+
func.tokens.client_id = self.user.id
217+
func.LOCAL_LANGS.clear()
218+
219+
async def on_command_error(self, ctx: commands.Context, exception, /) -> None:
220+
error = getattr(exception, 'original', exception)
221+
if ctx.interaction:
222+
error = getattr(error, 'original', error)
223+
if isinstance(error, (commands.CommandNotFound, aiohttp.client_exceptions.ClientOSError)):
224+
return
225+
226+
elif isinstance(error, (commands.CommandOnCooldown, commands.MissingPermissions, commands.RangeError, commands.BadArgument)):
227+
pass
228+
229+
elif isinstance(error, (commands.MissingRequiredArgument, commands.MissingRequiredAttachment)):
230+
command = f"{ctx.prefix}" + (f"{ctx.command.parent.qualified_name} " if ctx.command.parent else "") + f"{ctx.command.name} {ctx.command.signature}"
231+
position = command.find(f"<{ctx.current_parameter.name}>") + 1
232+
description = f"**Correct Usage:**\n```{command}\n" + " " * position + "^" * len(ctx.current_parameter.name) + "```\n"
233+
if ctx.command.aliases:
234+
description += f"**Aliases:**\n`{', '.join([f'{ctx.prefix}{alias}' for alias in ctx.command.aliases])}`\n\n"
235+
description += f"**Description:**\n{ctx.command.help}\n\u200b"
236+
237+
embed = discord.Embed(description=description, color=func.settings.embed_color)
238+
embed.set_footer(icon_url=ctx.me.display_avatar.url, text=f"More Help: {func.settings.invite_link}")
239+
return await ctx.reply(embed=embed)
240+
241+
elif not issubclass(error.__class__, VoicelinkException):
242+
error = func.get_lang(ctx.guild.id, "unknownException") + func.settings.invite_link
243+
if (guildId := ctx.guild.id) not in func.ERROR_LOGS:
244+
func.ERROR_LOGS[guildId][round(datetime.timestamp(datetime.now()))] = "".join(traceback.format_exception(type(exception), exception, exception.__traceback__))
245+
246+
try:
247+
return await ctx.reply(error, ephemeral=True)
248+
except:
249+
pass

gpt/choose_act.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
22
import aiohttp
3-
import logging
43
from gpt.gpt_response_gen import generate_response
54
from gpt.sendmessage import gpt_message
65
from gpt.vqa import vqa_answer
@@ -21,7 +20,7 @@ def internet_search(query: str, search_type: str):
2120
If the conversation contains a URL, select url
2221
Args:
2322
query (str): Query to search the web with
24-
search_type (str): Type of search to perform (one of [eat,url,general, image, youtube])
23+
search_type (str): Type of search to perform (one of [general,eat,url, image, youtube])
2524
"""
2625
pass
2726
```
@@ -115,7 +114,8 @@ def manage_user_data(user_id: str, user_data: str = None, action: str = 'read'):
115114
'''
116115
async def generate_image(message_to_edit, message,prompt: str, n_steps: int = 40, high_noise_frac: float = 0.8):
117116
await message_to_edit.edit(content="畫畫修練中")
118-
async def choose_act(prompt, message,message_to_edit):
117+
async def choose_act(bot,prompt, message,message_to_edit):
118+
logger = bot.get_logger_for_guild(message.guild.name)
119119
prompt = f"msgtime:[{str(datetime.now())[:-7]}]{prompt}"
120120
global system_prompt
121121
default_action_list = [
@@ -148,7 +148,7 @@ async def choose_act(prompt, message,message_to_edit):
148148
responses += response
149149
# 解析 JSON 字符串
150150
thread.join()
151-
#logging.info(responses)
151+
logger.info(responses)
152152
try:
153153
# 提取 JSON 部分
154154
json_start = responses.find("[")
@@ -162,9 +162,10 @@ async def choose_act(prompt, message,message_to_edit):
162162
action_list = default_action_list
163163

164164
async def execute_action(message_to_edit, dialogue_history, channel_id, original_prompt, message):
165+
logger = bot.get_logger_for_guild(message_to_edit.guild.name)
165166
nonlocal action_list, tool_func_dict
166167
final_results = []
167-
logging.info(action_list)
168+
logger.info(action_list)
168169
try:
169170
for action in action_list:
170171
tool_name = action["tool_name"]
@@ -181,13 +182,13 @@ async def execute_action(message_to_edit, dialogue_history, channel_id, original
181182
if result is not None and tool_name != "directly_answer":
182183
final_results.append(result)
183184
except Exception as e:
184-
logging.info(e)
185+
logger.info(e)
185186
else:
186-
logging.info(f"未知的工具函数: {tool_name}")
187+
logger.info(f"未知的工具函数: {tool_name}")
187188
finally:
188189
integrated_results = "\n".join(final_results)
189190
final_prompt = f'<<information:\n{integrated_results}\n{original_prompt}>>'
190191
gptresponses = await gpt_message(message_to_edit, message, final_prompt)
191192
dialogue_history[channel_id].append({"role": "assistant", "content": gptresponses})
192-
logging.info(f'PigPig:{gptresponses}')
193+
logger.info(f'PigPig:{gptresponses}')
193194
return execute_action

gpt/internet/google_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ async def google_search(message_to_edit,message,query):
2323
soup = BeautifulSoup(html, 'html.parser')
2424
search_results = soup.select('.g')
2525
search = ""
26-
for result in search_results[:5]:
26+
for result in search_results[:8]:
2727
title_element = result.select_one('h3')
2828
title = title_element.text if title_element else 'No Title'
2929
snippet_element = result.select_one('.VwiC3b')

gpt/sendmessage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ async def gpt_message(message_to_edit,message,prompt):
7171
related_data = search_vector_database(prompt) # 使用 LangChain 搜尋相關資料
7272
# 讀取該訊息頻道最近的歷史紀錄
7373
history = []
74-
async for msg in channel.history(limit=10):
74+
async for msg in channel.history(limit=5):
7575
history.append(msg)
7676
history.reverse()
7777
history = history[:-2]

logs.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
import logging
22
import os
33
from datetime import datetime
4+
45
class TimedRotatingFileHandler(logging.Handler):
5-
def __init__(self):
6+
def __init__(self, server_name):
67
super().__init__()
8+
self.server_name = server_name
79
self.current_date = datetime.now().strftime('%Y%m%d')
810
self.current_hour = datetime.now().strftime('%H')
911
self._create_new_folder()
1012
self._open_new_file()
1113

1214
def _create_new_folder(self):
13-
log_directory = f'logs/{self.current_date}'
15+
log_directory = f'logs/{self.server_name}/{self.current_date}'
1416
if not os.path.exists(log_directory):
1517
os.makedirs(log_directory)
1618
self.log_directory = log_directory
@@ -22,20 +24,28 @@ def _open_new_file(self):
2224
def emit(self, record):
2325
current_date = datetime.now().strftime('%Y%m%d')
2426
current_hour = datetime.now().strftime('%H')
25-
if current_date != self.current_date:
27+
if current_date != self.current_date or current_hour != self.current_hour:
2628
self.stream.close()
2729
self.current_date = current_date
28-
self._create_new_folder()
29-
self.current_hour = current_hour
30-
self._open_new_file()
31-
elif current_hour != self.current_hour:
32-
self.stream.close()
3330
self.current_hour = current_hour
31+
self._create_new_folder()
3432
self._open_new_file()
3533
msg = self.format(record)
3634
self.stream.write(msg + '\n')
3735
self.stream.flush()
3836

3937
def close(self):
4038
self.stream.close()
41-
super().close()
39+
super().close()
40+
41+
def setup_logger(server_name):
42+
logger = logging.getLogger(server_name)
43+
logger.setLevel(logging.INFO)
44+
handler = TimedRotatingFileHandler(server_name)
45+
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
46+
handler.setFormatter(formatter)
47+
# 移除所有默認的處理程序
48+
if logger.hasHandlers():
49+
logger.handlers.clear()
50+
logger.addHandler(handler)
51+
return logger

0 commit comments

Comments
 (0)