Skip to content

Commit

Permalink
Merge pull request #101 from AmiyaBot/dev
Browse files Browse the repository at this point in the history
fix: 修复 BUG
  • Loading branch information
vivien8261 authored Jul 1, 2024
2 parents 2ee880c + 467628b commit 63af1df
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 24 deletions.
11 changes: 11 additions & 0 deletions amiyabot/builtin/message/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,23 @@ def __init__(self, result: bool, weight: Union[int, float] = 0, keypoint: Option
self.weight = weight
self.keypoint = keypoint

self.on_selected: Optional[Callable] = None

def __bool__(self):
return bool(self.result)

def __repr__(self):
return f'<Verify, {self.result}, {self.weight}>'

def set_attrs(self, *attrs: Any):
indexes = [
'result',
'weight',
'keypoint',
]
for index, value in zip(indexes, attrs):
setattr(self, index, value)


@dataclass
class File:
Expand Down
73 changes: 49 additions & 24 deletions amiyabot/factory/implemented.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

from typing import List
from contextlib import contextmanager
from dataclasses import dataclass
from amiyabot.util import remove_prefix_once
from amiyabot.builtin.message import Message, MessageMatch, Verify, Equal
Expand All @@ -8,7 +10,7 @@

@dataclass
class MessageHandlerItemImpl(MessageHandlerItem):
def __check(self, data: Message, obj: KeywordsType) -> Verify:
def __check(self, result: Verify, data: Message, obj: KeywordsType):
methods = {
str: MessageMatch.check_str,
Equal: MessageMatch.check_equal,
Expand All @@ -18,33 +20,54 @@ def __check(self, data: Message, obj: KeywordsType) -> Verify:

if t in methods:
method = methods[t]
check = Verify(*method(data, obj, self.level))
check = result.set_attrs(*method(data, obj, self.level))
if check:
return check

elif t is list:
for item in obj:
check = self.__check(data, item)
check = self.__check(result, data, item)
if check:
return check

return Verify(False)
return result

@classmethod
def update_data(cls, data: Message, prefix_keywords: List[str]):
def func():
text, prefix = remove_prefix_once(data.text, prefix_keywords)
if prefix:
data.text_prefix = prefix
data.set_text(text, set_original=False)

return func

@classmethod
@contextmanager
def restore_data(cls, result: Verify, data: Message):
if result.on_selected:
result.on_selected()
yield
data.text_prefix = ''
data.set_text(data.text_original, set_original=False)

async def verify(self, data: Message):
result = Verify(False)

# 检查是否支持私信
direct_only = self.direct_only or (self.group_config and self.group_config.direct_only)

if data.is_direct:
if not direct_only:
if self.allow_direct is None:
if not self.group_config or not self.group_config.allow_direct:
return Verify(False)
return result

if self.allow_direct is False:
return Verify(False)
return result
else:
if direct_only:
return Verify(False)
return result

# 检查是否包含前缀触发词或被 @
flag = False
Expand All @@ -62,13 +85,11 @@ async def verify(self, data: Message):

if not prefix_keywords:
flag = True

# 如果前缀校验通过,再次修正 Message 对象的属性值
text, prefix = remove_prefix_once(data.text, prefix_keywords)
if prefix:
flag = True
data.text_prefix = prefix
data.set_text(text, set_original=False)
else:
_, prefix = remove_prefix_once(data.text, prefix_keywords)
if prefix:
flag = True
result.on_selected = self.update_data(data, prefix_keywords)

# 若不通过以上检查,且关键字不为全等句式(Equal)
# 则允许当关键字为列表时,筛选列表内的全等句式继续执行校验,否则校验不通过
Expand All @@ -77,23 +98,27 @@ async def verify(self, data: Message):
if equal_filter:
self.keywords = equal_filter
else:
return Verify(False)
return result

# 执行自定义校验并修正其返回值
if self.custom_verify:
result = await self.custom_verify(data)
with self.restore_data(result, data):
res = await self.custom_verify(data)

if isinstance(res, bool) or res is None:
result.result = bool(res)
result.weight = int(bool(res))

if isinstance(result, bool) or result is None:
result = result, int(bool(result)), None
elif isinstance(res, tuple):
contrast = bool(res[0]), int(bool(res[0])), None
res_len = len(res)
res = (res + contrast[res_len:])[:3]

elif isinstance(result, tuple):
contrast = bool(result[0]), int(bool(result[0])), None
result_len = len(result)
result = (result + contrast[result_len:])[:3]
result.set_attrs(*res)

return Verify(*result)
return result

return self.__check(data, self.keywords)
return self.__check(result, data, self.keywords)

async def action(self, data: Message):
return await self.function(data)
2 changes: 2 additions & 0 deletions amiyabot/handler/messageHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ async def choice_handlers(data: Message, handlers: List[MessageHandlerItem], wai

# 将 Verify 结果赋值给 Message
data.verify = selected[0]
if data.verify.on_selected:
data.verify.on_selected()

return selected[1]

Expand Down

0 comments on commit 63af1df

Please sign in to comment.