12
12
from _pytest .outcomes import Skipped # noqa
13
13
from _pytest .outcomes import XFailed # noqa
14
14
15
+ from ..... import check
15
16
from ..... import lang
16
17
from .....diag import pydevd as pdu
17
18
from .._registry import register
18
- from .backends .asyncio import AsyncioAsyncsBackend
19
- from .backends .base import AsyncsBackend
20
- from .backends .trio import TrioAsyncsBackend
21
- from .backends .trio_asyncio import TrioAsyncioAsyncsBackend
19
+ from .backends import ASYNC_BACKENDS
20
+ from .backends import AsyncsBackend
22
21
from .consts import ASYNCS_MARK
23
- from .consts import KNOWN_BACKENDS
24
22
from .consts import PARAM_NAME
25
23
from .fixtures import CANARY
26
24
from .fixtures import AsyncsFixture
41
39
42
40
@register
43
41
class AsyncsPlugin :
44
- ASYNC_BACKENDS : ta .ClassVar [ta .Sequence [str ]] = [
45
- * [s for s in KNOWN_BACKENDS if lang .can_import (s )],
46
- ]
42
+ def __init__ (self , backends : ta .Collection [type [AsyncsBackend ]] | None = None ) -> None :
43
+ super ().__init__ ()
44
+
45
+ if backends is None :
46
+ backends = ASYNC_BACKENDS
47
+
48
+ bd : dict [str , AsyncsBackend ] = {}
49
+ for bc in backends :
50
+ be = bc ()
51
+ if not be .is_available ():
52
+ continue
53
+ bn = be .name
54
+ check .not_in (bn , bd )
55
+ bd [bn ] = be
56
+ self ._backends = bd
47
57
48
58
def pytest_cmdline_main (self , config ):
49
59
if (aio_plugin := sys .modules .get ('pytest_asyncio.plugin' )):
@@ -67,7 +77,7 @@ def pytest_generate_tests(self, metafunc):
67
77
if m .args :
68
78
bes = m .args
69
79
else :
70
- bes = self .ASYNC_BACKENDS
80
+ bes = list ( self ._backends )
71
81
else :
72
82
return
73
83
@@ -120,45 +130,13 @@ def pytest_runtest_call(self, item):
120
130
yield
121
131
return
122
132
123
- be = item .callspec .params [PARAM_NAME ]
133
+ bn = item .callspec .params [PARAM_NAME ]
134
+ be = self ._backends [bn ]
124
135
125
- beo : AsyncsBackend
126
- if be == 'asyncio' :
127
- beo = AsyncioAsyncsBackend ()
128
- elif be == 'trio' :
129
- beo = TrioAsyncsBackend ()
130
- elif be == 'trio_asyncio' :
131
- beo = TrioAsyncioAsyncsBackend ()
132
- else :
133
- raise ValueError (be )
134
-
135
- item .obj = self .test_runner_factory (beo , item )
136
+ item .obj = self .test_runner_factory (be , item )
136
137
137
138
yield
138
139
139
- # bes = [be for be in self.ASYNC_BACKENDS if item.get_closest_marker(be) is not None]
140
- # if len(bes) > 1 and set(bes) != {'trio', 'trio_asyncio'}:
141
- # raise Exception(f'{item.nodeid}: multiple async backends specified: {bes}')
142
- # elif is_async_function(item.obj) and not bes:
143
- # from _pytest.unittest import UnitTestCase # noqa
144
- # if isinstance(item.parent, UnitTestCase):
145
- # # unittest handles these itself.
146
- # pass
147
- # else:
148
- # raise Exception(f'{item.nodeid}: async def function and no async plugin specified')
149
- #
150
- # if 'trio_asyncio' in bes:
151
- # obj = item.obj
152
- #
153
- # @functools.wraps(obj)
154
- # @trai.with_trio_asyncio_loop(wait=True)
155
- # async def run(*args, **kwargs):
156
- # await trio_asyncio.aio_as_trio(obj)(*args, **kwargs)
157
- #
158
- # item.obj = run
159
- #
160
- # yield
161
-
162
140
def test_runner_factory (self , backend : AsyncsBackend , item , testfunc = None ):
163
141
if not testfunc :
164
142
testfunc = item .obj
0 commit comments