From 3151672cfe8c6b885ccf9971b1b37db5f062b196 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Sat, 10 Apr 2021 20:51:06 -0400 Subject: [PATCH] [commands] Refactor typing evaluation to not use get_type_hints get_type_hints had a few issues: 1. It would convert = None default parameters to Optional 2. It would not allow values as type annotations 3. It would not implicitly convert some string literals as ForwardRef In Python 3.9 `list['Foo']` does not convert into `list[ForwardRef('Foo')]` even though `typing.List` does this behaviour. In order to streamline it, evaluation had to be rewritten manually to support our usecases. This patch also flattens nested typing.Literal which was not done until Python 3.9.2. --- discord/ext/commands/core.py | 161 ++++++++++++++++++++++------------- 1 file changed, 103 insertions(+), 58 deletions(-) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 9bd8f4f814ed..ab8c52c1b249 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -22,10 +22,20 @@ DEALINGS IN THE SOFTWARE. """ +from typing import ( + Any, + Dict, + ForwardRef, + Iterable, + Literal, + Tuple, + Union, + get_args as get_typing_args, + get_origin as get_typing_origin, +) import asyncio import functools import inspect -import typing import datetime import sys @@ -64,6 +74,83 @@ 'bot_has_guild_permissions' ) +PY_310 = sys.version_info >= (3, 10) + +def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: + params = [] + literal_cls = type(Literal[0]) + for p in parameters: + if isinstance(p, literal_cls): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) + +def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True): + if isinstance(tp, ForwardRef): + tp = tp.__forward_arg__ + # ForwardRefs always evaluate their internals + implicit_str = True + + if implicit_str and isinstance(tp, str): + if tp in cache: + return cache[tp] + evaluated = eval(tp, globals) + cache[tp] = evaluated + return _evaluate_annotation(evaluated, globals, cache) + + if hasattr(tp, '__args__'): + implicit_str = True + args = tp.__args__ + if tp.__origin__ is Literal: + if not PY_310: + args = flatten_literal_params(tp.__args__) + implicit_str = False + + evaluated_args = tuple( + _evaluate_annotation(arg, globals, cache, implicit_str=implicit_str) for arg in args + ) + + if evaluated_args == args: + return tp + + try: + return tp.copy_with(evaluated_args) + except AttributeError: + return tp.__origin__[evaluated_args] + + return tp + +def resolve_annotation(annotation: Any, globalns: Dict[str, Any], cache: Dict[str, Any] = {}) -> Any: + if annotation is None: + return type(None) + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + return _evaluate_annotation(annotation, globalns, cache) + +def get_signature_parameters(function) -> Dict[str, inspect.Parameter]: + globalns = function.__globals__ + signature = inspect.signature(function) + params = {} + cache: Dict[str, Any] = {} + for name, parameter in signature.parameters.items(): + annotation = parameter.annotation + if annotation is parameter.empty: + params[name] = parameter + continue + if annotation is None: + params[name] = parameter.replace(annotation=type(None)) + continue + + annotation = _evaluate_annotation(annotation, globalns, cache) + if annotation is converters.Greedy: + raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') + + params[name] = parameter.replace(annotation=annotation) + + return params + + def wrap_callback(coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): @@ -300,40 +387,7 @@ def callback(self): def callback(self, function): self._callback = function self.module = function.__module__ - - signature = inspect.signature(function) - self.params = signature.parameters.copy() - - # see: https://bugs.python.org/issue41341 - resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved - - try: - type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()} - except NameError as e: - raise NameError(f'unresolved forward reference: {e.args[0]}') from None - - for key, value in self.params.items(): - # coalesce the forward references - if key in type_hints: - self.params[key] = value = value.replace(annotation=type_hints[key]) - - # fail early for when someone passes an unparameterized Greedy type - if value.annotation is converters.Greedy: - raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') - - def _return_resolved(self, type, **kwargs): - return type - - def _recursive_resolve(self, type, *, globals=None): - if not isinstance(type, typing.ForwardRef): - return type - - resolved = eval(type.__forward_arg__, globals) - args = typing.get_args(resolved) - for index, arg in enumerate(args): - inner_resolve_result = self._recursive_resolve(arg, globals=globals) - resolved[index] = inner_resolve_result - return resolved + self.params = get_signature_parameters(function) def add_check(self, func): """Adds a check to the command. @@ -493,12 +547,12 @@ async def _actual_conversion(self, ctx, converter, argument, param): raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc async def do_conversion(self, ctx, converter, argument, param): - origin = typing.get_origin(converter) + origin = get_typing_origin(converter) - if origin is typing.Union: + if origin is Union: errors = [] _NoneType = type(None) - for conv in typing.get_args(converter): + for conv in get_typing_args(converter): # if we got to this part in the code, then the previous conversions have failed # so we should just undo the view, return the default, and allow parsing to continue # with the other parameters @@ -514,13 +568,12 @@ async def do_conversion(self, ctx, converter, argument, param): return value # if we're here, then we failed all the converters - raise BadUnionArgument(param, typing.get_args(converter), errors) + raise BadUnionArgument(param, get_typing_args(converter), errors) - if origin is typing.Literal: + if origin is Literal: errors = [] conversions = {} - literal_args = tuple(self._flattened_typing_literal_args(converter)) - for literal in literal_args: + for literal in converter.__args__: literal_type = type(literal) try: value = conversions[literal_type] @@ -538,7 +591,7 @@ async def do_conversion(self, ctx, converter, argument, param): return value # if we're here, then we failed to match all the literals - raise BadLiteralArgument(param, literal_args, errors) + raise BadLiteralArgument(param, converter.__args__, errors) return await self._actual_conversion(ctx, converter, argument, param) @@ -1021,14 +1074,7 @@ def short_doc(self): return '' def _is_typing_optional(self, annotation): - return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None) - - def _flattened_typing_literal_args(self, annotation): - for literal in typing.get_args(annotation): - if typing.get_origin(literal) is typing.Literal: - yield from self._flattened_typing_literal_args(literal) - else: - yield literal + return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None) @property def signature(self): @@ -1048,17 +1094,16 @@ def signature(self): # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the # parameter signature is a literal list of it's values annotation = param.annotation.converter if greedy else param.annotation - origin = typing.get_origin(annotation) - if not greedy and origin is typing.Union: - union_args = typing.get_args(annotation) + origin = get_typing_origin(annotation) + if not greedy and origin is Union: + union_args = get_typing_args(annotation) optional = union_args[-1] is type(None) if optional: annotation = union_args[0] - origin = typing.get_origin(annotation) + origin = get_typing_origin(annotation) - if origin is typing.Literal: - name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) - for v in self._flattened_typing_literal_args(annotation)) + if origin is Literal: + name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user.