1
+ import discord
2
+ import sys
3
+ import os
4
+ import re
5
+ import traceback
6
+ import aiohttp
7
+ import update
8
+ import function as func
9
+ import json
10
+ import logging
11
+ from zhconv import convert
12
+ from discord .ext import commands
13
+ from web import IPCServer
14
+ from motor .motor_asyncio import AsyncIOMotorClient
15
+ from datetime import datetime
16
+ from voicelink import VoicelinkException
17
+ from gpt .choose_act import choose_act
18
+ from gpt .sendmessage import gpt_message , load_and_index_dialogue_history , save_vector_store , vector_store
19
+ from logs import TimedRotatingFileHandler
20
+ class Translator (discord .app_commands .Translator ):
21
+ async def load (self ):
22
+ print ("Loaded Translator" )
23
+
24
+ async def unload (self ):
25
+ print ("Unload Translator" )
26
+
27
+ async def translate (self , string : discord .app_commands .locale_str , locale : discord .Locale , context : discord .app_commands .TranslationContext ):
28
+ if str (locale ) in func .LOCAL_LANGS :
29
+ return func .LOCAL_LANGS [str (locale )].get (string .message , None )
30
+ return None
31
+ # 配置 logging
32
+ def setup_logger (server_name ):
33
+ logger = logging .getLogger (server_name )
34
+ logger .setLevel (logging .INFO )
35
+ handler = TimedRotatingFileHandler (server_name )
36
+ formatter = logging .Formatter ('%(asctime)s %(levelname)s:%(message)s' )
37
+ handler .setFormatter (formatter )
38
+ logger .addHandler (handler )
39
+ return logger
40
+ class PigPig (commands .Bot ):
41
+ def __init__ (self , * args , ** kwargs ):
42
+ super ().__init__ (* args , ** kwargs )
43
+ self .dialogue_history_file = './data/dialogue_history.json'
44
+ self .vector_store_path = './data/vector_store'
45
+ self .load_dialogue_history ()
46
+ load_and_index_dialogue_history (self .dialogue_history_file )
47
+ self .ipc = IPCServer (
48
+ self ,
49
+ host = func .settings .ipc_server ["host" ],
50
+ port = func .settings .ipc_server ["port" ],
51
+ sercet_key = func .tokens .sercet_key
52
+ )
53
+ self .loggers = {}
54
+
55
+ def setup_logger_for_guild (self , guild_name ):
56
+ if guild_name not in self .loggers :
57
+ self .loggers [guild_name ] = setup_logger (guild_name )
58
+
59
+ def get_logger_for_guild (self , guild_name ):
60
+ if guild_name in self .loggers :
61
+ return self .loggers [guild_name ]
62
+ else :
63
+ self .setup_logger_for_guild (guild_name )
64
+ return self .loggers [guild_name ]
65
+
66
+ def setup_logger_for_guild (self , guild_name ):
67
+ if guild_name not in self .loggers :
68
+ self .loggers [guild_name ] = setup_logger (guild_name )
69
+
70
+ def load_dialogue_history (self ):
71
+ """從檔案中讀取對話歷史"""
72
+ try :
73
+ with open (self .dialogue_history_file , 'r' , encoding = 'utf-8' ) as file :
74
+ self .dialogue_history = json .load (file )
75
+ except FileNotFoundError :
76
+ self .dialogue_history = {}
77
+
78
+ def save_dialogue_history (self ):
79
+ """將對話歷史保存到檔案中"""
80
+ with open (self .dialogue_history_file , 'w' , encoding = 'utf-8' ) as file :
81
+ json .dump (self .dialogue_history , file , ensure_ascii = False , indent = 4 )
82
+ save_vector_store (vector_store , self .vector_store_path )
83
+
84
+ async def on_message (self , message : discord .Message , / ) -> None :
85
+ if message .author .bot or not message .guild :
86
+ return
87
+
88
+ guild_name = message .guild .name
89
+ self .setup_logger_for_guild (guild_name )
90
+ logger = self .loggers [guild_name ]
91
+
92
+ logger .info (f'收到訊息: { message .content } (來自:伺服器:{ message .guild } ,頻道:{ message .channel .name } ,{ message .author .name } )' )
93
+ await self .process_commands (message )
94
+
95
+ channel_id = str (message .channel .id )
96
+ if channel_id not in self .dialogue_history :
97
+ self .dialogue_history [channel_id ] = []
98
+
99
+ try :
100
+ match = re .search (r"<@\d+>\s*(.*)" , message .content )
101
+ prompt = match .group (1 )
102
+ except AttributeError : # 如果正則表達式沒有匹配到,會拋出 AttributeError
103
+ prompt = message .content
104
+
105
+ self .dialogue_history [channel_id ].append ({"role" : "user" , "content" : prompt })
106
+ # 實現生成回應的邏輯
107
+ if self .user .id in message .raw_mentions and not message .mention_everyone :
108
+ # 發送初始訊息
109
+ message_to_edit = await message .reply ("思考中..." )
110
+ try :
111
+ execute_action = await choose_act (self ,prompt , message , message_to_edit )
112
+ await execute_action (message_to_edit , self .dialogue_history , channel_id , prompt , message )
113
+ except Exception as e :
114
+ print (e )
115
+ self .save_dialogue_history ()
116
+
117
+ async def on_message_edit (self , before : discord .Message , after : discord .Message ):
118
+ if before .author .bot or not before .guild :
119
+ return
120
+
121
+ logger = self .get_logger_for_guild (before .guild .name )
122
+ logger .info (
123
+ f"訊息修改: 原訊息({ before .content } ) 新訊息({ after .content } ) 頻道:{ before .channel .name } , 作者:{ before .author } "
124
+ )
125
+ channel_id = str (after .channel .id )
126
+ if channel_id not in self .dialogue_history :
127
+ self .dialogue_history [channel_id ] = []
128
+
129
+ try :
130
+ match = re .search (r"<@\d+>\s*(.*)" , after .content )
131
+ prompt = match .group (1 )
132
+ except AttributeError : # 如果正則表達式沒有匹配到,會拋出 AttributeError
133
+ prompt = after .content
134
+
135
+ self .dialogue_history [channel_id ].append ({"role" : "user" , "content" : prompt })
136
+
137
+ # 實現生成回應的邏輯
138
+ if self .user .id in after .raw_mentions and not after .mention_everyone :
139
+ try :
140
+ # Fetch the bot's previous reply
141
+ async for msg in after .channel .history (limit = 50 ):
142
+ if msg .reference and msg .reference .message_id == before .id and msg .author .id == self .user .id :
143
+ await msg .delete () # 删除之前的回复
144
+
145
+ message_to_edit = await after .reply ("思考中..." ) # 创建新的回复
146
+ execute_action = await choose_act (self ,prompt , after , message_to_edit )
147
+ await execute_action (message_to_edit , self .dialogue_history , channel_id , prompt , after )
148
+ except Exception as e :
149
+ print (e )
150
+ self .save_dialogue_history ()
151
+
152
+
153
+ async def connect_db (self ) -> None :
154
+ if not ((db_name := func .tokens .mongodb_name ) and (db_url := func .tokens .mongodb_url )):
155
+ raise Exception ("MONGODB_NAME and MONGODB_URL can't not be empty in settings.json" )
156
+
157
+ try :
158
+ func .MONGO_DB = AsyncIOMotorClient (host = db_url , serverSelectionTimeoutMS = 5000 )
159
+ await func .MONGO_DB .server_info ()
160
+ print ("Successfully connected to MongoDB!" )
161
+
162
+ except Exception as e :
163
+ raise Exception ("Not able to connect MongoDB! Reason:" , e )
164
+
165
+ func .SETTINGS_DB = func .MONGO_DB [db_name ]["Settings" ]
166
+ func .USERS_DB = func .MONGO_DB [db_name ]["Users" ]
167
+
168
+ async def setup_hook (self ) -> None :
169
+ func .langs_setup ()
170
+
171
+ # Connecting to MongoDB
172
+ await self .connect_db ()
173
+ # Loading all the module in `cogs` folder
174
+ for module in os .listdir (func .ROOT_DIR + '/cogs' ):
175
+ if module .endswith ('.py' ):
176
+ try :
177
+ await self .load_extension (f"cogs.{ module [:- 3 ]} " )
178
+ print (f"Loaded { module [:- 3 ]} " )
179
+ except Exception as e :
180
+ print (traceback .format_exc ())
181
+
182
+ if func .settings .ipc_server .get ("enable" , False ):
183
+ await self .ipc .start ()
184
+
185
+ if not func .settings .version or func .settings .version != update .__version__ :
186
+ func .update_json ("settings.json" , new_data = {"version" : update .__version__ })
187
+
188
+ await self .tree .set_translator (Translator ())
189
+ await self .tree .sync ()
190
+
191
+ async def on_ready (self ):
192
+ print ("------------------" )
193
+ print (f"Logging As { self .user } " )
194
+ print (f"Bot ID: { self .user .id } " )
195
+ print ("------------------" )
196
+ print (f"Discord Version: { discord .__version__ } " )
197
+ print (f"Python Version: { sys .version } " )
198
+ print ("------------------" )
199
+ data = {}
200
+ data ['guilds' ] = []
201
+ for guild in self .guilds :
202
+ guild_info = {
203
+ 'guild_name' : guild .name ,'guild_id' : guild .id ,
204
+ 'channels' : []
205
+ }
206
+ for channel in guild .channels :
207
+ channel_info = f"channel_name: { channel .name } ,channel_id: { channel .id } ,channel_type: { str (channel .type )} "
208
+ guild_info ['channels' ].append (channel_info )
209
+ data ['guilds' ].append (guild_info )
210
+ self .setup_logger_for_guild (guild .name ) # 設置每個伺服器的 logger
211
+
212
+ # 將資料寫入 JSON 文件
213
+ with open ('logs/guilds_and_channels.json' , 'w' , encoding = 'utf-8' ) as f :
214
+ json .dump (data , f , ensure_ascii = False , indent = 4 )
215
+ print ('update succesfully guilds_and_channels.json' )
216
+ func .tokens .client_id = self .user .id
217
+ func .LOCAL_LANGS .clear ()
218
+
219
+ async def on_command_error (self , ctx : commands .Context , exception , / ) -> None :
220
+ error = getattr (exception , 'original' , exception )
221
+ if ctx .interaction :
222
+ error = getattr (error , 'original' , error )
223
+ if isinstance (error , (commands .CommandNotFound , aiohttp .client_exceptions .ClientOSError )):
224
+ return
225
+
226
+ elif isinstance (error , (commands .CommandOnCooldown , commands .MissingPermissions , commands .RangeError , commands .BadArgument )):
227
+ pass
228
+
229
+ elif isinstance (error , (commands .MissingRequiredArgument , commands .MissingRequiredAttachment )):
230
+ command = f"{ ctx .prefix } " + (f"{ ctx .command .parent .qualified_name } " if ctx .command .parent else "" ) + f"{ ctx .command .name } { ctx .command .signature } "
231
+ position = command .find (f"<{ ctx .current_parameter .name } >" ) + 1
232
+ description = f"**Correct Usage:**\n ```{ command } \n " + " " * position + "^" * len (ctx .current_parameter .name ) + "```\n "
233
+ if ctx .command .aliases :
234
+ description += f"**Aliases:**\n `{ ', ' .join ([f'{ ctx .prefix } { alias } ' for alias in ctx .command .aliases ])} `\n \n "
235
+ description += f"**Description:**\n { ctx .command .help } \n \u200b "
236
+
237
+ embed = discord .Embed (description = description , color = func .settings .embed_color )
238
+ embed .set_footer (icon_url = ctx .me .display_avatar .url , text = f"More Help: { func .settings .invite_link } " )
239
+ return await ctx .reply (embed = embed )
240
+
241
+ elif not issubclass (error .__class__ , VoicelinkException ):
242
+ error = func .get_lang (ctx .guild .id , "unknownException" ) + func .settings .invite_link
243
+ if (guildId := ctx .guild .id ) not in func .ERROR_LOGS :
244
+ func .ERROR_LOGS [guildId ][round (datetime .timestamp (datetime .now ()))] = "" .join (traceback .format_exception (type (exception ), exception , exception .__traceback__ ))
245
+
246
+ try :
247
+ return await ctx .reply (error , ephemeral = True )
248
+ except :
249
+ pass
0 commit comments