Skip to content

Commit

Permalink
Merge pull request #106 from AmiyaBot/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
vivien8261 authored Jul 31, 2024
2 parents 658f730 + 702a2b9 commit a944b2e
Show file tree
Hide file tree
Showing 13 changed files with 97 additions and 65 deletions.
5 changes: 2 additions & 3 deletions amiyabot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def __init__(
if not appid:
appid = random_code(10)

super().__init__(appid, token, adapter)
super().__init__(appid, token, adapter, private)

self.private = private
self.send_message = self.instance.send_message

self.__closed = False
Expand All @@ -71,7 +70,7 @@ async def start(self, launch_browser: typing.Union[bool, BrowserLaunchConfig] =
await basic_browser_service.launch(BrowserLaunchConfig() if launch_browser is True else launch_browser)

self.run_timed_tasks()
await self.instance.start(self.private, self.__message_handler)
await self.instance.start(self.__message_handler)

async def close(self):
if not self.__closed:
Expand Down
9 changes: 6 additions & 3 deletions amiyabot/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@


class BotAdapterProtocol:
def __init__(self, appid: str, token: str):
def __init__(self, appid: str, token: str, private: bool = False):
self.appid = appid
self.token = token
self.alive = False
self.keep_run = True

self.private = private

# 适配器实例连接信息
self.host: Optional[str] = None
self.ws_port: Optional[int] = None
self.http_port: Optional[int] = None
self.session: Optional[str] = None
self.headers: Optional[dict] = None

self.bot_name = ''

self.log = LoggerManager(self.__str__())
self.bot: Optional[T_BotHandlerFactory] = None

Expand Down Expand Up @@ -73,11 +77,10 @@ async def close(self):
raise NotImplementedError

@abc.abstractmethod
async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
"""
启动实例,执行 handler 方法处理消息
:param private: 是否私域机器人
:param handler: 消息处理方法
"""
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/comwechat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ComWeChatBotInstance(OneBot12Instance):
def __str__(self):
return 'ComWeChat'

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
while self.keep_run:
await self.keep_connect(handler, package_method=package_com_wechat_message)
await asyncio.sleep(10)
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/kook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def api(self):
def __still_alive(self):
return self.keep_run and self.connection

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
me_req = await self.api.get_me()
if me_req:
self.appid = me_req.json['data']['id']
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/mirai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def close(self):
if self.connection:
await self.connection.close()

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
while self.keep_run:
await self.keep_connect(handler)
await asyncio.sleep(10)
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/onebot/v11/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def close(self):
if self.connection:
await self.connection.close()

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
while self.keep_run:
await self.keep_connect(handler)
await asyncio.sleep(10)
Expand Down
2 changes: 1 addition & 1 deletion amiyabot/adapters/onebot/v12/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def close(self):
if self.connection:
await self.connection.close()

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
while self.keep_run:
await self.keep_connect(handler)
await asyncio.sleep(10)
Expand Down
16 changes: 12 additions & 4 deletions amiyabot/adapters/tencent/qqGlobal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@


class QQGlobalBotInstance(QQGroupBotInstance):
def __init__(self, appid: str, token: str, client_secret: str, default_chain_builder: ChainBuilder):
super().__init__(appid, token, client_secret, default_chain_builder)

self.guild = QQGuildBotInstance(appid, token)
def __init__(
self,
appid: str,
token: str,
client_secret: str,
default_chain_builder: ChainBuilder,
shard_index: int,
shards: int,
):
super().__init__(appid, token, client_secret, default_chain_builder, shard_index, shards)

self.guild = QQGuildBotInstance(appid, token, shard_index, shards)

def __str__(self):
return 'QQGlobal'
Expand Down
20 changes: 15 additions & 5 deletions amiyabot/adapters/tencent/qqGroup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@


class QQGroupBotInstance(QQGuildBotInstance):
def __init__(self, appid: str, token: str, client_secret: str, default_chain_builder: ChainBuilder):
super().__init__(appid, token)
def __init__(
self,
appid: str,
token: str,
client_secret: str,
default_chain_builder: ChainBuilder,
shard_index: int,
shards: int,
):
super().__init__(appid, token, shard_index, shards)

self.__access_token_api = QQGroupAPI(self.appid, self.token, client_secret)
self.__default_chain_builder = default_chain_builder
Expand All @@ -33,14 +41,16 @@ def build_adapter(
client_secret: str,
default_chain_builder: Optional[ChainBuilder] = None,
default_chain_builder_options: QQGroupChainBuilderOptions = QQGroupChainBuilderOptions(),
shard_index: int = 0,
shards: int = 1,
):
def adapter(appid: str, token: str):
if default_chain_builder:
cb = default_chain_builder
else:
cb = QQGroupChainBuilder(default_chain_builder_options)

return cls(appid, token, client_secret, cb)
return cls(appid, token, client_secret, cb, shard_index, shards)

return adapter

Expand All @@ -52,14 +62,14 @@ def api(self):
def package_method(self):
return package_qq_group_message

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
if hasattr(self.__default_chain_builder, 'start'):
self.__default_chain_builder.start()

if not self.__seq_service.alive:
asyncio.create_task(self.__seq_service.run())

await super().start(private, handler)
await super().start(handler)

async def close(self):
await self.__seq_service.stop()
Expand Down
95 changes: 53 additions & 42 deletions amiyabot/adapters/tencent/qqGuild/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,28 @@
from amiyabot.adapters.tencent.intents import get_intents

from .api import QQGuildAPI, log
from .model import GateWay, Payload, ShardsRecord, ConnectionHandler
from .model import GateWay, Payload, ConnectionModel, ConnectionHandler
from .package import package_qq_guild_message
from .builder import build_message_send, QQGuildMessageCallback


def qq_guild_shards(shard_index: int, shards: int):
def adapter(appid: str, token: str):
return QQGuildBotInstance(appid, token, shard_index, shards)

return adapter


class QQGuildBotInstance(BotAdapterProtocol):
def __init__(self, appid: str, token: str):
def __init__(self, appid: str, token: str, shard_index: int = 0, shards: int = 1):
super().__init__(appid, token)

self.appid = appid
self.token = token
self.bot_name = ''
self.shard_index = shard_index
self.shards = shards

self.shards_record: Dict[int, ShardsRecord] = {}
self.model: Optional[ConnectionModel] = None

def __str__(self):
return 'QQGuild'
Expand All @@ -37,49 +45,61 @@ def api(self):
def package_method(self):
return package_qq_guild_message

def __create_heartbeat(self, websocket, interval: int, record: ShardsRecord):
def __create_heartbeat(self, websocket, interval: int):
heartbeat_key = random_code(10)
record.heartbeat_key = heartbeat_key
asyncio.create_task(self.heartbeat_interval(websocket, interval, record.shards_index, heartbeat_key))
self.model.heartbeat_key = heartbeat_key
asyncio.create_task(
self.heartbeat_interval(
websocket,
interval,
heartbeat_key,
)
)

async def close(self):
log.info(f'closing {self}(appid {self.appid})...')
self.keep_run = False

for _, item in self.shards_record.items():
if item.connection:
await item.connection.close()
if self.model:
await self.model.connection.close()

async def start(self, private: bool, handler: HANDLER_TYPE):
async def start(self, handler: HANDLER_TYPE):
log.info(f'requesting appid {self.appid} gateway')

resp = await self.api.gateway_bot()

if not resp or 'url' not in resp.json:
if self.keep_run:
await asyncio.sleep(10)
asyncio.create_task(self.start(private, handler))
asyncio.create_task(self.start(handler))
return False

gateway = GateWay(**resp.json)

log.info(
f'appid {self.appid} gateway resp: shards {gateway.shards}, remaining %d/%d'
f'appid {self.appid} gateway resp: shards {gateway.shards}, max_concurrency %d, remaining %d/%d'
% (
gateway.session_start_limit['max_concurrency'],
gateway.session_start_limit['remaining'],
gateway.session_start_limit['total'],
)
)

await self.create_connection(ConnectionHandler(private=private, gateway=gateway, message_handler=handler))
await self.create_connection(
ConnectionHandler(
private=self.private,
gateway=gateway,
message_handler=handler,
)
)

async def create_connection(self, handler: ConnectionHandler, shards_index: int = 0):
async def create_connection(self, handler: ConnectionHandler):
gateway = handler.gateway
sign = f'{self.appid} {shards_index + 1}/{gateway.shards}'
sign = f'{self.appid} {self.shard_index + 1}/{self.shards}'

async with self.get_websocket_connection(sign, gateway.url) as websocket:
if websocket:
self.shards_record[shards_index] = ShardsRecord(shards_index, connection=websocket)
self.model = ConnectionModel(connection=websocket)

while self.keep_run:
await asyncio.sleep(0)
Expand All @@ -94,19 +114,15 @@ async def create_connection(self, handler: ConnectionHandler, shards_index: int
f'connected({sign}): {self.bot_name}({self}-%s)'
% ('private' if handler.private else 'public')
)
self.shards_record[shards_index].session_id = payload.d['session_id']

if shards_index == 0 and gateway.shards > 1:
for n in range(gateway.shards - 1):
asyncio.create_task(self.create_connection(handler, n + 1))
self.model.session_id = payload.d['session_id']
else:
await self.create_package_task(handler, payload)

if payload.op == 10:
create_token = {
'token': f'Bot {self.appid}.{self.token}',
'intents': get_intents(handler.private, self.__str__()),
'shard': [shards_index, gateway.shards],
'shard': [self.shard_index, self.shards],
'properties': {
'$os': sys.platform,
'$browser': '',
Expand All @@ -115,25 +131,21 @@ async def create_connection(self, handler: ConnectionHandler, shards_index: int
}
await websocket.send(Payload(op=2, d=create_token).to_json())

self.__create_heartbeat(
websocket,
payload.d['heartbeat_interval'],
self.shards_record[shards_index],
)
self.__create_heartbeat(websocket, payload.d['heartbeat_interval'])

if payload.s:
self.shards_record[shards_index].last_s = payload.s
self.model.last_s = payload.s

while self.keep_run and self.shards_record[shards_index].reconnect_limit > 0:
await self.reconnect(handler, self.shards_record[shards_index], sign)
while self.keep_run and self.model.reconnect_limit > 0:
await self.reconnect(handler, sign)
await asyncio.sleep(1)

async def reconnect(self, handler: ConnectionHandler, record: ShardsRecord, sign: str):
async def reconnect(self, handler: ConnectionHandler, sign: str):
log.info(f'reconnecting({sign})...')

async with self.get_websocket_connection(sign, handler.gateway.url) as websocket:
if websocket:
record.connection = websocket
self.model.connection = websocket

while self.keep_run:
await asyncio.sleep(0)
Expand All @@ -150,34 +162,33 @@ async def reconnect(self, handler: ConnectionHandler, record: ShardsRecord, sign
if payload.op == 10:
reconnect_token = {
'token': f'Bot {self.appid}.{self.token}',
'session_id': record.session_id,
'seq': record.last_s,
'session_id': self.model.session_id,
'seq': self.model.last_s,
}
await websocket.send(Payload(op=6, d=reconnect_token).to_json())

self.__create_heartbeat(websocket, payload.d['heartbeat_interval'], record)
self.__create_heartbeat(websocket, payload.d['heartbeat_interval'])

record.reconnect_limit = 3
self.model.reconnect_limit = 3

if payload.s:
record.last_s = payload.s
self.model.last_s = payload.s

record.reconnect_limit -= 1
self.model.reconnect_limit -= 1

async def heartbeat_interval(
self,
websocket: WebSocketClientProtocol,
interval: int,
shards_index: int,
heartbeat_key: str,
):
sec = 0
while self.keep_run and self.shards_record[shards_index].heartbeat_key == heartbeat_key:
while self.keep_run and self.model.heartbeat_key == heartbeat_key:
await asyncio.sleep(1)
sec += 1
if sec >= interval / 1000:
sec = 0
await websocket.send(Payload(op=1, d=self.shards_record[shards_index].last_s).to_json())
await websocket.send(Payload(op=1, d=self.model.last_s).to_json())

async def create_package_task(self, handler: ConnectionHandler, payload: Payload):
asyncio.create_task(
Expand Down
3 changes: 1 addition & 2 deletions amiyabot/adapters/tencent/qqGuild/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class ConnectionHandler:


@dataclass
class ShardsRecord:
shards_index: int
class ConnectionModel:
session_id: Optional[str] = None
last_s: Optional[int] = None
reconnect_limit: int = 3
Expand Down
Loading

0 comments on commit a944b2e

Please sign in to comment.