@@ -30,44 +30,54 @@ class ConfigRunner:
def __init__(self):
    """Initialize config runner instance.

    All bookkeeping attributes are created here so that no method can hit
    an ``AttributeError`` before the first spec has been handled.
    """
    self._loop = asyncio.get_event_loop()

    # Gate for spec processing: handle_new_spec() waits on it and clears it
    # once a spec has been accepted; unblock_next_config() sets it again.
    self._event = asyncio.Event()
    self._event.set()  # initially no configure running

    # In-progress per-world configuration tasks, keyed by world name.
    self._world_tasks: dict[str, asyncio.Task] = {}

    # No spec has been received yet.
    self._spec: ServeConfig | None = None

    # Names of the worlds covered by the configuration currently in flight.
    self._curr_worlds_to_configure: set[str] = set()

    # Start empty (instead of annotation-only declarations) so the
    # attributes exist even before handle_new_spec() has run once.
    self._worlds_to_add: list[WorldInfo] = []
    self._worlds_to_remove: list[WorldInfo] = []
    self._world_infos: dict[str, WorldInfo] = {}
4241
async def handle_new_spec(self, spec: ServeConfig) -> None:
    """Handle new spec.

    Steps:
      1. Work out which worlds of the in-flight configuration are affected
         by ``spec`` and cancel only those.
      2. Wait until the in-flight configuration has finished (the event is
         set again through ``unblock_next_config``).
      3. Rebuild the world infos from the new spec, compute the worlds to
         add / remove, and block the next spec by clearing the event.

    Args:
        spec: the newly received serve config; stored as the current spec.
    """
    if self._spec is None or not self._curr_worlds_to_configure:
        # First spec, or nothing is being configured right now: there is
        # nothing to cancel, and the spec diff must not be computed
        # against a missing previous spec.
        worlds_to_cancel: set[str] = set()
    else:
        new_worlds_to_configure = ServeConfig.get_worlds_to_configure(
            self._spec, spec
        )
        worlds_to_cancel = new_worlds_to_configure & self._curr_worlds_to_configure

    self._spec = spec

    if worlds_to_cancel:
        # since _worlds_to_add are used in pipeline, we need to update them
        # by removing the ones that will be cancelled.
        self._worlds_to_add = [
            world_info
            for world_info in self._worlds_to_add
            if world_info.name not in worlds_to_cancel
        ]
        await self._cancel_world_configuration(worlds_to_cancel)

    # after cancelling affected worlds, wait for configuration to finish
    # then process next config.
    await self._event.wait()

    new_world_infos = self._build_world_infos()
    new = new_world_infos.keys()
    cur = self._world_infos.keys()

    self._worlds_to_add = self.get_words_to_add(new_world_infos, new, cur)
    self._worlds_to_remove = self.get_words_to_remove(self._world_infos, new, cur)

    # key views support set difference; the result is a plain set of names
    self._curr_worlds_to_configure = new - cur

    # block configuration until current config is processed
    self._event.clear()
def unblock_next_config(self) -> None:
    """Set task event and unblock next config process.

    Setting the event releases a ``handle_new_spec`` call that is waiting
    for the current configuration to finish.
    """
    self._event.set()
6777
def reset_curr_worlds_to_configure(self) -> None:
    """Reset current worlds to configure."""
    # Rebind to a fresh set rather than clearing in place, so any reference
    # to the previous set held elsewhere is left untouched.
    self._curr_worlds_to_configure = set()
7181
7282 def set_world_infos (self ) -> None :
7383 """Set new world infos."""
@@ -98,23 +108,28 @@ def get_words_to_remove(
98108 """Return a list of world infos to remove."""
99109 return [world_infos [name ] for name in cur - new ]
100110
async def _cancel_world_configuration(self, world_names: set[str]) -> None:
    """Cancel only worlds that are impacted by new spec.

    Args:
        world_names: names of the worlds whose in-progress configuration
            must be cancelled.
    """
    # return_exceptions=True collects whatever each cancellation raises
    # (including the CancelledError of the cancelled task) instead of
    # aborting the remaining cancellations or propagating to the caller.
    await asyncio.gather(
        *(self._cancel_world(name) for name in world_names),
        return_exceptions=True,
    )
115+
def schedule_world_cfg(
    self,
    world_info: WorldInfo,
    coro_factory: Callable[[WorldInfo], Awaitable[None]],
) -> asyncio.Task:
    """Schedule configuration for a single world.

    Args:
        world_info: the world to configure; its ``name`` is the key under
            which the task is tracked.
        coro_factory: callable that receives ``world_info`` and returns the
            awaitable performing the configuration.

    Returns:
        The created task, which is also registered in ``_world_tasks`` so
        it can be cancelled by world name later.
    """
    task = self._loop.create_task(self._run_world(world_info, coro_factory))
    self._world_tasks[world_info.name] = task
    return task
123+
async def _cancel_world(self, world_name: str) -> None:
    """Cancel an in-progress world config task.

    The task is removed from ``_world_tasks`` first; if it is still
    running it is cancelled and awaited so its cleanup has completed before
    this coroutine finishes.

    Args:
        world_name: name of the world whose task should be cancelled.
            Unknown names and already finished tasks are ignored.

    Raises:
        asyncio.CancelledError: propagated from awaiting the cancelled
            task. ``_cancel_world_configuration`` runs this coroutine under
            ``gather(..., return_exceptions=True)``, which collects it.
    """
    task = self._world_tasks.pop(world_name, None)
    if task is None or task.done():
        return

    task.cancel()
    # Awaiting lets the task unwind; the resulting CancelledError is
    # deliberately not swallowed here.
    await task
118133
119134 def _build_world_infos (self ) -> dict [str , WorldInfo ]:
120135 world_infos : dict [str , WorldInfo ] = {}
@@ -170,13 +185,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
170185
171186 return world_infos
172187
async def _run_world(
    self,
    world_info: WorldInfo,
    coro_factory: Callable[[WorldInfo], Awaitable[None]],
) -> None:
    """Run and cleanup world configuration.

    Args:
        world_info: the world being configured; passed to ``coro_factory``.
        coro_factory: callable returning the awaitable that configures the
            given world.

    Any exception raised by the configuration, including
    ``asyncio.CancelledError``, propagates to the caller after cleanup.
    """
    try:
        await coro_factory(world_info)
    finally:
        # Always drop the bookkeeping entry for this world, whether the
        # configuration finished, failed, or was cancelled.
        self._world_tasks.pop(world_info.name, None)
0 commit comments