Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 存储 matcher 发送 prompt 的结果 #3155

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nonebot/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
"""当前 `reject` 目标存储 key"""
REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target"
"""下一个 `reject` 目标存储 key"""
PAUSE_PROMPT_RESULT_KEY: Literal["_pause_{key}_result"] = "_pause_{key}_result"
"""`pause` prompt 发送结果存储 key"""
REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result"
"""`reject` prompt 发送结果存储 key"""

# used by Rule
PREFIX_KEY: Literal["_prefix"] = "_prefix"
Expand Down
34 changes: 25 additions & 9 deletions nonebot/internal/matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
from nonebot.consts import (
ARG_KEY,
LAST_RECEIVE_KEY,
PAUSE_PROMPT_RESULT_KEY,
RECEIVE_KEY,
REJECT_CACHE_TARGET,
REJECT_PROMPT_RESULT_KEY,
REJECT_TARGET,
)
from nonebot.dependencies import Dependent, Param
Expand Down Expand Up @@ -560,8 +562,8 @@
"""
bot = current_bot.get()
event = current_event.get()
state = current_matcher.get().state
if isinstance(message, MessageTemplate):
state = current_matcher.get().state

Check warning on line 566 in nonebot/internal/matcher/matcher.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/matcher/matcher.py#L566

Added line #L566 was not covered by tests
_message = message.format(**state)
else:
_message = message
Expand Down Expand Up @@ -597,8 +599,15 @@
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
请参考对应 adapter 的 bot 对象 api
"""
try:
matcher = current_matcher.get()
except Exception:
matcher = None

Check warning on line 605 in nonebot/internal/matcher/matcher.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/matcher/matcher.py#L604-L605

Added lines #L604 - L605 were not covered by tests

if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
if matcher is not None:
matcher.state[PAUSE_PROMPT_RESULT_KEY] = result
raise PausedException

@classmethod
Expand All @@ -615,8 +624,19 @@
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
请参考对应 adapter 的 bot 对象 api
"""
try:
matcher = current_matcher.get()
key = matcher.get_target()
except Exception:
matcher = None
key = None

Check warning on line 632 in nonebot/internal/matcher/matcher.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/matcher/matcher.py#L630-L632

Added lines #L630 - L632 were not covered by tests

key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None

if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
if key is not None and matcher:
matcher.state[key] = result
raise RejectedException

@classmethod
Expand All @@ -637,9 +657,7 @@
"""
matcher = current_matcher.get()
matcher.set_target(ARG_KEY.format(key=key))
if prompt is not None:
await cls.send(prompt, **kwargs)
raise RejectedException
await cls.reject(prompt, **kwargs)

@classmethod
async def reject_receive(
Expand All @@ -659,9 +677,7 @@
"""
matcher = current_matcher.get()
matcher.set_target(RECEIVE_KEY.format(id=id))
if prompt is not None:
await cls.send(prompt, **kwargs)
raise RejectedException
await cls.reject(prompt, **kwargs)

@classmethod
def skip(cls) -> NoReturn:
Expand Down
43 changes: 33 additions & 10 deletions nonebot/internal/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydantic.fields import FieldInfo as PydanticFieldInfo

from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY
from nonebot.dependencies import Dependent, Param
from nonebot.dependencies.utils import check_field_type
from nonebot.exception import SkippedException
Expand All @@ -39,7 +40,7 @@
)

if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher


Expand Down Expand Up @@ -522,10 +523,10 @@

class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
) -> None:
self.key: Optional[str] = key
self.type: Literal["message", "str", "plaintext"] = type
self.type: Literal["message", "str", "plaintext", "prompt"] = type

def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})"
Expand All @@ -546,6 +547,11 @@
return ArgInner(key, "plaintext") # type: ignore


def ArgPromptResult(key: Optional[str] = None) -> Any:
"""`arg` prompt 发送结果"""
return ArgInner(key, "prompt")


class ArgParam(Param):
"""Arg 注入参数

Expand All @@ -559,7 +565,7 @@
self,
*args,
key: str,
type: Literal["message", "str", "plaintext"],
type: Literal["message", "str", "plaintext", "prompt"],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -584,15 +590,32 @@
async def _solve( # pyright: ignore[reportIncompatibleMethodOverride]
self, matcher: "Matcher", **kwargs: Any
) -> Any:
message = matcher.get_arg(self.key)
if message is None:
return message
if self.type == "message":
return message
return self._solve_message(matcher)
elif self.type == "str":
return str(message)
return self._solve_str(matcher)
elif self.type == "plaintext":
return self._solve_plaintext(matcher)
elif self.type == "prompt":
return self._solve_prompt(matcher)
else:
return message.extract_plain_text()
raise ValueError(f"Unknown Arg type: {self.type}")

Check warning on line 602 in nonebot/internal/params.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/params.py#L602

Added line #L602 was not covered by tests

def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
return matcher.get_arg(self.key)

def _solve_str(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return str(message) if message is not None else None

def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return message.extract_plain_text() if message is not None else None

def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
)


class ExceptionParam(Param):
Expand Down
25 changes: 25 additions & 0 deletions nonebot/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@
ENDSWITH_KEY,
FULLMATCH_KEY,
KEYWORD_KEY,
PAUSE_PROMPT_RESULT_KEY,
PREFIX_KEY,
RAW_CMD_KEY,
RECEIVE_KEY,
REGEX_MATCHED,
REJECT_PROMPT_RESULT_KEY,
SHELL_ARGS,
SHELL_ARGV,
STARTSWITH_KEY,
)
from nonebot.internal.params import Arg as Arg
from nonebot.internal.params import ArgParam as ArgParam
from nonebot.internal.params import ArgPlainText as ArgPlainText
from nonebot.internal.params import ArgPromptResult as ArgPromptResult
from nonebot.internal.params import ArgStr as ArgStr
from nonebot.internal.params import BotParam as BotParam
from nonebot.internal.params import DefaultParam as DefaultParam
Expand Down Expand Up @@ -252,6 +256,26 @@ def _last_received(matcher: "Matcher") -> Any:
return Depends(_last_received, use_cache=False)


def ReceivePromptResult(id: Optional[str] = None) -> Any:
"""`receive` prompt 发送结果"""

def _receive_prompt_result(matcher: "Matcher") -> Any:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id))
)

return Depends(_receive_prompt_result, use_cache=False)


def PausePromptResult() -> Any:
"""`pause` prompt 发送结果"""

def _pause_prompt_result(matcher: "Matcher") -> Any:
return matcher.state.get(PAUSE_PROMPT_RESULT_KEY)

return Depends(_pause_prompt_result, use_cache=False)


__autodoc__ = {
"Arg": True,
"ArgStr": True,
Expand All @@ -265,4 +289,5 @@ def _last_received(matcher: "Matcher") -> Any:
"DefaultParam": True,
"MatcherParam": True,
"ExceptionParam": True,
"ArgPromptResult": True,
}
8 changes: 6 additions & 2 deletions tests/plugins/param/param_arg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from typing import Annotated, Any

from nonebot.adapters import Message
from nonebot.params import Arg, ArgPlainText, ArgStr
from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr


async def arg(key: Message = Arg()) -> Message:
Expand All @@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
return key


async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any:
return key


# test dependency priority
async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()):
return key
Expand Down
17 changes: 15 additions & 2 deletions tests/plugins/param/param_matcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import TypeVar, Union
from typing import Any, TypeVar, Union

from nonebot.adapters import Event
from nonebot.matcher import Matcher
from nonebot.params import LastReceived, Received
from nonebot.params import (
LastReceived,
PausePromptResult,
Received,
ReceivePromptResult,
)


async def matcher(m: Matcher) -> Matcher:
Expand Down Expand Up @@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event:

async def last_receive(e: Event = LastReceived()) -> Event:
return e


async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any:
return result


async def pause_prompt_result(result: Any = PausePromptResult()) -> Any:
return result
32 changes: 32 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@
import pytest

from nonebot.consts import (
ARG_KEY,
CMD_ARG_KEY,
CMD_KEY,
CMD_START_KEY,
CMD_WHITESPACE_KEY,
ENDSWITH_KEY,
FULLMATCH_KEY,
KEYWORD_KEY,
PAUSE_PROMPT_RESULT_KEY,
PREFIX_KEY,
RAW_CMD_KEY,
RECEIVE_KEY,
REGEX_MATCHED,
REJECT_PROMPT_RESULT_KEY,
SHELL_ARGS,
SHELL_ARGV,
STARTSWITH_KEY,
Expand Down Expand Up @@ -469,8 +473,10 @@ async def test_matcher(app: App):
matcher,
not_legacy_matcher,
not_matcher,
pause_prompt_result,
postpone_matcher,
receive,
receive_prompt_result,
sub_matcher,
union_matcher,
)
Expand Down Expand Up @@ -538,12 +544,31 @@ async def test_matcher(app: App):
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(event_next)

fake_matcher.state[
REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id="test"))
] = True

async with app.test_dependent(
receive_prompt_result, allow_types=[MatcherParam, DependParam]
) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(True)

fake_matcher.state[PAUSE_PROMPT_RESULT_KEY] = True

async with app.test_dependent(
pause_prompt_result, allow_types=[MatcherParam, DependParam]
) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(True)


@pytest.mark.anyio
async def test_arg(app: App):
from plugins.param.param_arg import (
annotated_arg,
annotated_arg_plain_text,
annotated_arg_prompt_result,
annotated_arg_str,
annotated_multi_arg,
annotated_prior_arg,
Expand All @@ -555,6 +580,7 @@ async def test_arg(app: App):
matcher = Matcher()
message = FakeMessage("text")
matcher.set_arg("key", message)
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key="key"))] = True

async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
Expand Down Expand Up @@ -582,6 +608,12 @@ async def test_arg(app: App):
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())

async with app.test_dependent(
annotated_arg_prompt_result, allow_types=[ArgParam]
) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(True)

async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())
Expand Down
Loading