1
1
import asyncio
2
+ import functools
2
3
import logging
3
4
from collections .abc import Callable , Coroutine , Generator
5
+ from contextvars import ContextVar
4
6
from typing import Any
5
7
8
+ # context for the currently running Automation; used by @atomic
9
+ _CURRENT_AUTOMATION : ContextVar [object | None ] = ContextVar ('rosys_automation_current' , default = None )
10
+
11
+
12
+ def atomic (func : Callable ):
13
+ """Decorator to make a function (sync or async) uninterruptible by pause until it exits."""
14
+ if asyncio .iscoroutinefunction (func ):
15
+ @functools .wraps (func )
16
+ async def _wrapped (* args , ** kwargs ):
17
+ automation = _CURRENT_AUTOMATION .get ()
18
+ if automation is not None :
19
+ automation ._atomic_depth += 1 # type: ignore[attr-defined]
20
+ try :
21
+ return await func (* args , ** kwargs )
22
+ finally :
23
+ automation ._atomic_depth -= 1 # type: ignore[attr-defined]
24
+ return await func (* args , ** kwargs )
25
+ return _wrapped
26
+ else :
27
+ @functools .wraps (func )
28
+ def _wrapped (* args , ** kwargs ):
29
+ automation = _CURRENT_AUTOMATION .get ()
30
+ if automation is not None :
31
+ automation ._atomic_depth += 1 # type: ignore[attr-defined]
32
+ try :
33
+ return func (* args , ** kwargs )
34
+ finally :
35
+ automation ._atomic_depth -= 1 # type: ignore[attr-defined]
36
+ return func (* args , ** kwargs )
37
+ return _wrapped
38
+
6
39
7
40
class Automation :
8
41
"""An automation wraps a coroutine and allows pausing and resuming it.
@@ -23,6 +56,7 @@ def __init__(self,
23
56
self ._can_run .set ()
24
57
self ._stop = False
25
58
self ._is_waited = False
59
+ self ._atomic_depth = 0 # >0 while inside an @atomic section
26
60
27
61
@property
28
62
def is_running (self ) -> bool :
@@ -40,6 +74,7 @@ async def run(self) -> Any | None:
40
74
return await self
41
75
42
76
def __await__ (self ) -> Generator [Any , None , Any | None ]:
77
+ token = _CURRENT_AUTOMATION .set (self ) # bind this Automation instance into the task context
43
78
try :
44
79
self ._is_waited = True
45
80
coro_iter = self .coro .__await__ ()
@@ -48,7 +83,7 @@ def __await__(self) -> Generator[Any, None, Any | None]:
48
83
message : Any = None
49
84
while not self ._stop :
50
85
try :
51
- while not self ._can_run .is_set () and not self ._stop :
86
+ while self . _atomic_depth == 0 and not self ._can_run .is_set () and not self ._stop :
52
87
yield from self ._can_run .wait ().__await__ () # pylint: disable=no-member
53
88
except BaseException as err :
54
89
send , message = iter_throw , err
@@ -75,6 +110,7 @@ def __await__(self) -> Generator[Any, None, Any | None]:
75
110
raise
76
111
finally :
77
112
self ._is_waited = False
113
+ _CURRENT_AUTOMATION .reset (token )
78
114
return None
79
115
80
116
def pause (self ) -> None :
0 commit comments