|
| 1 | +# https://github.com/python-trio/outcome/tree/6a3192f306ead4900a33fa8c47e5af5430e37692 |
| 2 | +# |
| 3 | +# The MIT License (MIT) |
| 4 | +# |
| 5 | +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated |
| 6 | +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the |
| 7 | +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit |
| 8 | +# persons to whom the Software is furnished to do so, subject to the following conditions: |
| 9 | +# |
| 10 | +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the |
| 11 | +# Software. |
| 12 | +# |
| 13 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE |
| 14 | +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
| 15 | +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR |
| 16 | +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
| 17 | +import abc |
| 18 | +import typing as ta |
| 19 | +import dataclasses as dc |
| 20 | + |
| 21 | + |
| 22 | +ValueT = ta.TypeVar('ValueT', covariant=True) |
| 23 | +ResultT = ta.TypeVar('ResultT') |
| 24 | +ArgsT = ta.ParamSpec('ArgsT') |
| 25 | + |
| 26 | + |
| 27 | +## |
| 28 | + |
| 29 | + |
| 30 | +class AlreadyUsedError(RuntimeError): |
| 31 | + """An Outcome can only be unwrapped once.""" |
| 32 | + |
| 33 | + |
| 34 | +def remove_tb_frames(exc: BaseException, n: int) -> BaseException: |
| 35 | + tb = exc.__traceback__ |
| 36 | + for _ in range(n): |
| 37 | + assert tb is not None |
| 38 | + tb = tb.tb_next |
| 39 | + return exc.with_traceback(tb) |
| 40 | + |
| 41 | + |
| 42 | +## |
| 43 | + |
| 44 | + |
| 45 | +@ta.overload |
| 46 | +def capture( |
| 47 | + # NoReturn = raises exception, so we should get an error. |
| 48 | + sync_fn: ta.Callable[ArgsT, ta.NoReturn], |
| 49 | + *args: ArgsT.args, |
| 50 | + **kwargs: ArgsT.kwargs, |
| 51 | +) -> 'Error': |
| 52 | + ... |
| 53 | + |
| 54 | + |
| 55 | +@ta.overload |
| 56 | +def capture( |
| 57 | + sync_fn: ta.Callable[ArgsT, ResultT], |
| 58 | + *args: ArgsT.args, |
| 59 | + **kwargs: ArgsT.kwargs, |
| 60 | +) -> ta.Union['Value[ResultT]', 'Error']: |
| 61 | + ... |
| 62 | + |
| 63 | + |
| 64 | +def capture( |
| 65 | + sync_fn: ta.Callable[ArgsT, ResultT], |
| 66 | + *args: ArgsT.args, |
| 67 | + **kwargs: ArgsT.kwargs, |
| 68 | +) -> ta.Union['Value[ResultT]', 'Error']: |
| 69 | + """ |
| 70 | + Run ``sync_fn(*args, **kwargs)`` and capture the result. |
| 71 | +
|
| 72 | + Returns: |
| 73 | + Either a :class:`Value` or :class:`Error` as appropriate. |
| 74 | + """ |
| 75 | + |
| 76 | + try: |
| 77 | + return Value(sync_fn(*args, **kwargs)) |
| 78 | + except BaseException as exc: |
| 79 | + exc = remove_tb_frames(exc, 1) |
| 80 | + return Error(exc) |
| 81 | + |
| 82 | + |
| 83 | +# |
| 84 | + |
| 85 | + |
| 86 | +@ta.overload |
| 87 | +async def acapture( |
| 88 | + async_fn: ta.Callable[ArgsT, ta.Awaitable[ta.NoReturn]], |
| 89 | + *args: ArgsT.args, |
| 90 | + **kwargs: ArgsT.kwargs, |
| 91 | +) -> 'Error': |
| 92 | + ... |
| 93 | + |
| 94 | + |
| 95 | +@ta.overload |
| 96 | +async def acapture( |
| 97 | + async_fn: ta.Callable[ArgsT, ta.Awaitable[ResultT]], |
| 98 | + *args: ArgsT.args, |
| 99 | + **kwargs: ArgsT.kwargs, |
| 100 | +) -> ta.Union['Value[ResultT]', 'Error']: |
| 101 | + ... |
| 102 | + |
| 103 | + |
| 104 | +async def acapture( |
| 105 | + async_fn: ta.Callable[ArgsT, ta.Awaitable[ResultT]], |
| 106 | + *args: ArgsT.args, |
| 107 | + **kwargs: ArgsT.kwargs, |
| 108 | +) -> ta.Union['Value[ResultT]', 'Error']: |
| 109 | + """ |
| 110 | + Run ``await async_fn(*args, **kwargs)`` and capture the result. |
| 111 | +
|
| 112 | + Returns: |
| 113 | + Either a :class:`Value` or :class:`Error` as appropriate. |
| 114 | + """ |
| 115 | + |
| 116 | + try: |
| 117 | + return Value(await async_fn(*args, **kwargs)) |
| 118 | + |
| 119 | + except BaseException as exc: |
| 120 | + exc = remove_tb_frames(exc, 1) |
| 121 | + return Error(exc) |
| 122 | + |
| 123 | + |
| 124 | +## |
| 125 | + |
| 126 | + |
| 127 | +@dc.dataclass(repr=False, init=False, slots=True, frozen=True, order=True) |
| 128 | +class Outcome(abc.ABC, ta.Generic[ValueT]): |
| 129 | + """ |
| 130 | + An abstract class representing the result of a Python computation. |
| 131 | +
|
| 132 | + This class has two concrete subclasses: :class:`Value` representing a value, and :class:`Error` representing an |
| 133 | + exception. |
| 134 | +
|
| 135 | + In addition to the methods described below, comparison operators on :class:`Value` and :class:`Error` objects |
| 136 | + (``==``, ``<``, etc.) check that the other object is also a :class:`Value` or :class:`Error` object respectively, |
| 137 | + and then compare the contained objects. |
| 138 | +
|
| 139 | + :class:`Outcome` objects are hashable if the contained objects are hashable. |
| 140 | + """ |
| 141 | + |
| 142 | + _unwrapped: bool = dc.field(default=False, compare=False, init=False) |
| 143 | + |
| 144 | + def _set_unwrapped(self) -> None: |
| 145 | + if self._unwrapped: |
| 146 | + raise AlreadyUsedError |
| 147 | + object.__setattr__(self, '_unwrapped', True) |
| 148 | + |
| 149 | + @abc.abstractmethod |
| 150 | + def unwrap(self) -> ValueT: |
| 151 | + """ |
| 152 | + Return or raise the contained value or exception. |
| 153 | +
|
| 154 | + These two lines of code are equivalent:: |
| 155 | +
|
| 156 | + x = fn(*args) |
| 157 | + x = outcome.capture(fn, *args).unwrap() |
| 158 | + """ |
| 159 | + |
| 160 | + @abc.abstractmethod |
| 161 | + def send(self, gen: ta.Generator[ResultT, ValueT, object]) -> ResultT: |
| 162 | + """ |
| 163 | + Send or throw the contained value or exception into the given generator object. |
| 164 | +
|
| 165 | + Args: |
| 166 | + gen: A generator object supporting ``.send()`` and ``.throw()`` methods. |
| 167 | + """ |
| 168 | + |
| 169 | + @abc.abstractmethod |
| 170 | + async def asend(self, agen: ta.AsyncGenerator[ResultT, ValueT]) -> ResultT: |
| 171 | + """ |
| 172 | + Send or throw the contained value or exception into the given async generator object. |
| 173 | +
|
| 174 | + Args: |
| 175 | + agen: An async generator object supporting ``.asend()`` and ``.athrow()`` methods. |
| 176 | + """ |
| 177 | + |
| 178 | + |
| 179 | +@ta.final |
| 180 | +@dc.dataclass(frozen=True, repr=False, slots=True, order=True) |
| 181 | +class Value(Outcome[ValueT], ta.Generic[ValueT]): |
| 182 | + """Concrete :class:`Outcome` subclass representing a regular value.""" |
| 183 | + |
| 184 | + value: ValueT |
| 185 | + |
| 186 | + def __repr__(self) -> str: |
| 187 | + return f'Value({self.value!r})' |
| 188 | + |
| 189 | + def unwrap(self) -> ValueT: |
| 190 | + self._set_unwrapped() |
| 191 | + return self.value |
| 192 | + |
| 193 | + def send(self, gen: ta.Generator[ResultT, ValueT, object]) -> ResultT: |
| 194 | + self._set_unwrapped() |
| 195 | + return gen.send(self.value) |
| 196 | + |
| 197 | + async def asend(self, agen: ta.AsyncGenerator[ResultT, ValueT]) -> ResultT: |
| 198 | + self._set_unwrapped() |
| 199 | + return await agen.asend(self.value) |
| 200 | + |
| 201 | + |
| 202 | +@ta.final |
| 203 | +@dc.dataclass(frozen=True, repr=False, slots=True, order=True) |
| 204 | +class Error(Outcome[ta.NoReturn]): |
| 205 | + """Concrete :class:`Outcome` subclass representing a raised exception.""" |
| 206 | + |
| 207 | + error: BaseException |
| 208 | + |
| 209 | + def __post_init__(self) -> None: |
| 210 | + if not isinstance(self.error, BaseException): |
| 211 | + raise TypeError(self.error) |
| 212 | + |
| 213 | + def __repr__(self) -> str: |
| 214 | + return f'Error({self.error!r})' |
| 215 | + |
| 216 | + def unwrap(self) -> ta.NoReturn: |
| 217 | + self._set_unwrapped() |
| 218 | + |
| 219 | + # Tracebacks show the 'raise' line below out of context, so let's give this variable a name that makes sense out |
| 220 | + # of context. |
| 221 | + captured_error = self.error |
| 222 | + |
| 223 | + try: |
| 224 | + raise captured_error |
| 225 | + |
| 226 | + finally: |
| 227 | + # We want to avoid creating a reference cycle here. Python does collect cycles just fine, so it wouldn't be |
| 228 | + # the end of the world if we did create a cycle, but the cyclic garbage collector adds latency to Python |
| 229 | + # programs, and the more cycles you create, the more often it runs, so it's nicer to avoid creating them in |
| 230 | + # the first place. For more details see: |
| 231 | + # |
| 232 | + # https://github.com/python-trio/trio/issues/1770 |
| 233 | + # |
| 234 | + # In particular, by deleting this local variables from the 'unwrap' methods frame, we avoid the |
| 235 | + # 'captured_error' object's __traceback__ from indirectly referencing 'captured_error'. |
| 236 | + del captured_error, self |
| 237 | + |
| 238 | + def send(self, gen: ta.Generator[ResultT, ta.NoReturn, object]) -> ResultT: |
| 239 | + self._set_unwrapped() |
| 240 | + return gen.throw(self.error) |
| 241 | + |
| 242 | + async def asend(self, agen: ta.AsyncGenerator[ResultT, ta.NoReturn]) -> ResultT: |
| 243 | + self._set_unwrapped() |
| 244 | + return await agen.athrow(self.error) |
| 245 | + |
| 246 | + |
| 247 | +# A convenience alias to a union of both results, allowing exhaustiveness checking. |
| 248 | +Maybe = Value[ValueT] | Error |
0 commit comments