Skip to content

Commit b25a94a

Browse files
introduce @automation.atomic decorator
1 parent 433da74 commit b25a94a

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

rosys/automation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .app_controls_ import AppButton
22
from .app_controls_ import AppControls as app_controls
3+
from .automation import atomic
34
from .automation_controls_ import AutomationControls as automation_controls
45
from .automator import Automator
56
from .parallelize import parallelize
@@ -10,6 +11,7 @@
1011
'Automator',
1112
'Schedule',
1213
'app_controls',
14+
'atomic',
1315
'automation_controls',
1416
'parallelize',
1517
]

rosys/automation/automation.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,41 @@
11
import asyncio
2+
import functools
23
import logging
34
from collections.abc import Callable, Coroutine, Generator
5+
from contextvars import ContextVar
46
from typing import Any
57

8+
# context for the currently running Automation; used by @atomic
9+
_CURRENT_AUTOMATION: ContextVar[object | None] = ContextVar('rosys_automation_current', default=None)
10+
11+
12+
def atomic(func: Callable):
13+
"""Decorator to make a function (sync or async) uninterruptible by pause until it exits."""
14+
if asyncio.iscoroutinefunction(func):
15+
@functools.wraps(func)
16+
async def _wrapped(*args, **kwargs):
17+
automation = _CURRENT_AUTOMATION.get()
18+
if automation is not None:
19+
automation._atomic_depth += 1 # type: ignore[attr-defined]
20+
try:
21+
return await func(*args, **kwargs)
22+
finally:
23+
automation._atomic_depth -= 1 # type: ignore[attr-defined]
24+
return await func(*args, **kwargs)
25+
return _wrapped
26+
else:
27+
@functools.wraps(func)
28+
def _wrapped(*args, **kwargs):
29+
automation = _CURRENT_AUTOMATION.get()
30+
if automation is not None:
31+
automation._atomic_depth += 1 # type: ignore[attr-defined]
32+
try:
33+
return func(*args, **kwargs)
34+
finally:
35+
automation._atomic_depth -= 1 # type: ignore[attr-defined]
36+
return func(*args, **kwargs)
37+
return _wrapped
38+
639

740
class Automation:
841
"""An automation wraps a coroutine and allows pausing and resuming it.
@@ -23,6 +56,7 @@ def __init__(self,
2356
self._can_run.set()
2457
self._stop = False
2558
self._is_waited = False
59+
self._atomic_depth = 0 # >0 while inside an @atomic section
2660

2761
@property
2862
def is_running(self) -> bool:
@@ -40,6 +74,7 @@ async def run(self) -> Any | None:
4074
return await self
4175

4276
def __await__(self) -> Generator[Any, None, Any | None]:
77+
token = _CURRENT_AUTOMATION.set(self) # bind this Automation instance into the task context
4378
try:
4479
self._is_waited = True
4580
coro_iter = self.coro.__await__()
@@ -48,7 +83,7 @@ def __await__(self) -> Generator[Any, None, Any | None]:
4883
message: Any = None
4984
while not self._stop:
5085
try:
51-
while not self._can_run.is_set() and not self._stop:
86+
while self._atomic_depth == 0 and not self._can_run.is_set() and not self._stop:
5287
yield from self._can_run.wait().__await__() # pylint: disable=no-member
5388
except BaseException as err:
5489
send, message = iter_throw, err
@@ -75,6 +110,7 @@ def __await__(self) -> Generator[Any, None, Any | None]:
75110
raise
76111
finally:
77112
self._is_waited = False
113+
_CURRENT_AUTOMATION.reset(token)
78114
return None
79115

80116
def pause(self) -> None:

0 commit comments

Comments
 (0)