Skip to content

Commit f335c89

Browse files
committed
feat: handle individual world configuration
Since a single configuration task for the entire config couldn't handle partial world updates, we implemented per-world configuration tasks. The config runner takes care of scheduling and canceling these tasks based on the diff between the old and the new config, and it orchestrates processing of configs according to what needs to be cancelled. Updated the code with proper error handlers for cancelled tasks so that cancellation propagates correctly to the parent task. With this, each world's configuration can be tracked individually, making it easier to implement retry-on-failure in the future.
1 parent d90516b commit f335c89

File tree

4 files changed

+106
-52
lines changed

4 files changed

+106
-52
lines changed

infscale/configs/job.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,16 @@ def get_worlds_to_configure(
157157
"""Compare two specs and return new and updated worlds."""
158158
helper = ServeConfigHelper()
159159

160-
curr_worlds = helper._get_worlds(curr_spec)
161160
new_worlds = helper._get_worlds(new_spec)
161+
new_world_names = set(new_worlds.keys())
162+
163+
# if current spec is not available, return worlds from
164+
# the new spec.
165+
if curr_spec is None:
166+
return new_world_names
162167

168+
curr_worlds = helper._get_worlds(curr_spec)
163169
curr_world_names = set(curr_worlds.keys())
164-
new_world_names = set(new_worlds.keys())
165170

166171
deploy_worlds = new_world_names - curr_world_names
167172

infscale/execution/config_runner.py

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,44 +30,54 @@ class ConfigRunner:
3030
def __init__(self):
3131
"""Initialize config runner instance."""
3232
self._loop = asyncio.get_event_loop()
33-
self._task: asyncio.Task | None = None
3433
self._event = asyncio.Event()
34+
self._world_tasks: dict[str, asyncio.Task] = {}
3535
self._spec: ServeConfig = None
3636
self._event.set() # initially no configure running
3737
self._curr_worlds_to_configure: set[str] = set()
38-
self._cancel_cur_cfg = False
3938
self._worlds_to_add: list[WorldInfo]
4039
self._worlds_to_remove: list[WorldInfo]
4140
self._world_infos: dict[str, WorldInfo] = {}
4241

43-
def handle_new_spec(self, spec: ServeConfig) -> None:
42+
async def handle_new_spec(self, spec: ServeConfig) -> None:
4443
"""Handle new spec."""
45-
self._cancel_cur_cfg = self._should_cancel_current(spec)
44+
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
45+
worlds_to_cancel = new_worlds_to_configure & self._curr_worlds_to_configure
4646
self._spec = spec
47-
new_world_infos = self._build_world_infos()
4847

48+
if len(worlds_to_cancel):
49+
# since _worlds_to_add are used in pipeline, we need to update them
50+
# by removing the ones that will be cancelled.
51+
self._worlds_to_add = [
52+
world_info
53+
for world_info in self._worlds_to_add
54+
if world_info.name not in worlds_to_cancel
55+
]
56+
await self._cancel_world_configuration(worlds_to_cancel)
57+
58+
# after cancelling affected worlds, wait for configuration to finish
59+
# then process next config.
60+
await self._event.wait()
61+
62+
new_world_infos = self._build_world_infos()
4963
new = new_world_infos.keys()
5064
cur = self._world_infos.keys()
5165

5266
self._worlds_to_add = self.get_words_to_add(new_world_infos, new, cur)
5367
self._worlds_to_remove = self.get_words_to_remove(self._world_infos, new, cur)
5468

55-
def _should_cancel_current(self, spec: ServeConfig) -> bool:
56-
"""Decide if current configuration should be cancelled."""
57-
if self._spec is None:
58-
return False
69+
self._curr_worlds_to_configure = new - cur
5970

60-
new_worlds_to_configure = ServeConfig.get_worlds_to_configure(self._spec, spec)
71+
# block configuration until current config is processed
72+
self._event.clear()
6173

62-
# cancel if the new config affects worlds currently being configured
63-
# TODO: if there's a overlap between new worlds and curr worlds we cancel
64-
# current configuration. This needs to be fixed, to cancel only the worlds that
65-
# are affected (eg new_worlds & curr_worlds)
66-
return not new_worlds_to_configure.isdisjoint(self._curr_worlds_to_configure)
74+
def unblock_next_config(self) -> None:
75+
"""Set task event and unblock next config process."""
76+
self._event.set()
6777

68-
def set_worlds_to_configure(self, world_names: set[str]) -> None:
69-
"""Set the world names currently being configured."""
70-
self._curr_worlds_to_configure = world_names
78+
def reset_curr_worlds_to_configure(self) -> None:
79+
"""Reset current worlds to configure."""
80+
self._curr_worlds_to_configure = set()
7181

7282
def set_world_infos(self) -> None:
7383
"""Set new world infos."""
@@ -98,23 +108,28 @@ def get_words_to_remove(
98108
"""Return a list of world infos to remove."""
99109
return [world_infos[name] for name in cur - new]
100110

101-
async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
102-
"""Cancel any in-progress configure and schedule a new one."""
103-
# wait for current to finish if we do not want to cancel
104-
if not self._cancel_cur_cfg:
105-
await self._event.wait()
106-
107-
# cancel current if running
108-
if self._task and not self._task.done():
109-
self._task.cancel()
111+
async def _cancel_world_configuration(self, world_names: set[str]):
112+
"""Cancel only worlds that are impacted by new spec."""
113+
coroutines = [self._cancel_world(w) for w in world_names]
114+
await asyncio.gather(*coroutines, return_exceptions=True)
115+
116+
def schedule_world_cfg(
117+
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
118+
):
119+
"""Schedule configuration for a single world."""
120+
task = self._loop.create_task(self._run_world(world_info, coro_factory))
121+
self._world_tasks[world_info.name] = task
122+
return task
123+
124+
async def _cancel_world(self, world_name: str):
125+
"""Cancel an in-progress world config task."""
126+
task = self._world_tasks.pop(world_name, None)
127+
if task and not task.done():
128+
task.cancel()
110129
try:
111-
await self._task
130+
await task
112131
except asyncio.CancelledError:
113-
pass
114-
115-
# block again for new run
116-
self._event.clear()
117-
self._task = self._loop.create_task(self._run(coro_factory))
132+
raise
118133

119134
def _build_world_infos(self) -> dict[str, WorldInfo]:
120135
world_infos: dict[str, WorldInfo] = {}
@@ -170,13 +185,13 @@ def _build_world_infos(self) -> dict[str, WorldInfo]:
170185

171186
return world_infos
172187

173-
async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
174-
"""Run coroutine factory."""
188+
async def _run_world(
189+
self, world_info: WorldInfo, coro_factory: Callable[[], Awaitable[None]]
190+
):
191+
"""Run and cleanup world configuration."""
175192
try:
176-
await coro_factory()
193+
await coro_factory(world_info)
177194
except asyncio.CancelledError:
178-
pass
195+
raise
179196
finally:
180-
# reset class attributes and events
181-
self._event.set()
182-
self._curr_worlds_to_configure = set()
197+
self._world_tasks.pop(world_info.name, None)

infscale/execution/control.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,32 @@ async def setup(self) -> None:
162162
if self.rank == 0:
163163
self._server_task = asyncio.create_task(self._setup_server(setup_done))
164164
else:
165-
_ = asyncio.create_task(self._setup_client(setup_done))
165+
client_task = asyncio.create_task(self._setup_client(setup_done))
166166

167167
# wait until setting up either server or client is done
168-
await setup_done.wait()
168+
try:
169+
await setup_done.wait()
170+
except asyncio.CancelledError as e:
171+
# logger.warning(f"[{self.rank}] channel setup cancelled")
172+
# since both _setup_server and _setup_client are spawned as separate tasks
173+
# and the setup itself is a task, we need to handle parent task cancellation
174+
# on the awaited line, since cancellation only propagates through awaited calls
175+
# here, await setup_done.wait() is the propagation point from parent task to child tasks
176+
# so we need to cancel child tasks whenever CancelledError is received
177+
if self._server_task and not self._server_task.done():
178+
self._server_task.cancel()
179+
try:
180+
await self._server_task
181+
except asyncio.CancelledError:
182+
pass
183+
184+
if client_task and not client_task.done():
185+
client_task.cancel()
186+
try:
187+
await client_task
188+
except asyncio.CancelledError:
189+
pass
190+
raise
169191

170192
def cleanup(self) -> None:
171193
if self._server_task is not None:

infscale/execution/pipeline.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ async def _configure_multiworld(self, world_info: WorldInfo) -> None:
104104
port=port,
105105
device=self.device,
106106
)
107+
except asyncio.CancelledError:
108+
logger.warning(f"multiworld configuration cancelled for {world_info.name}")
107109
except Exception as e:
108110
logger.error(f"failed to initialize a multiworld {name}: {e}")
109111
condition = self._status != WorkerStatus.UPDATING
@@ -130,9 +132,12 @@ def _set_n_send_worker_status(self, status: WorkerStatus) -> None:
130132
self.wcomm.send(msg)
131133

132134
async def _configure_control_channel(self, world_info: WorldInfo) -> None:
133-
await world_info.channel.setup()
135+
try:
136+
await world_info.channel.setup()
134137

135-
await world_info.channel.wait_readiness()
138+
await world_info.channel.wait_readiness()
139+
except asyncio.CancelledError:
140+
logger.warning(f"channel configuration cancelled for {world_info.name} ")
136141

137142
def _reset_multiworld(self, world_info: WorldInfo) -> None:
138143
self.world_manager.remove_world(world_info.multiworld_name)
@@ -180,7 +185,9 @@ async def _configure(self) -> None:
180185
tasks = []
181186
# 1. set up control channel
182187
for world_info in self.config_runner._worlds_to_add:
183-
task = self._configure_control_channel(world_info)
188+
task = self.config_runner.schedule_world_cfg(
189+
world_info, self._configure_control_channel
190+
)
184191
tasks.append(task)
185192

186193
# TODO: this doesn't handle partial success
@@ -190,7 +197,9 @@ async def _configure(self) -> None:
190197
tasks = []
191198
# 2. set up multiworld
192199
for world_info in self.config_runner._worlds_to_add:
193-
task = self._configure_multiworld(world_info)
200+
task = self.config_runner.schedule_world_cfg(
201+
world_info, self._configure_multiworld
202+
)
194203
tasks.append(task)
195204

196205
# TODO: this doesn't handle partial success
@@ -220,6 +229,9 @@ async def _configure(self) -> None:
220229

221230
worker_status = WorkerStatus.RUNNING if is_first_run else WorkerStatus.UPDATED
222231

232+
# config is done, do cleanup in config runner
233+
self.config_runner.reset_curr_worlds_to_configure()
234+
self.config_runner.unblock_next_config()
223235
self._set_n_send_worker_status(worker_status)
224236

225237
self.cfg_event.set()
@@ -423,14 +435,14 @@ async def _handle_config(self, spec: ServeConfig) -> None:
423435

424436
await self._cleanup_recovered_worlds()
425437

426-
self.config_runner.handle_new_spec(spec)
427-
428438
self._inspector.configure(self.spec)
429439

430440
self._initialize_once()
431-
432441
# (re)configure the pipeline
433-
await self.config_runner.schedule(self._configure)
442+
await self.config_runner.handle_new_spec(spec)
443+
# run configure as a separate task since we need to unblock receiving
444+
# a new config to be processed when current configuration is finished
445+
self._configure_task = asyncio.create_task(self._configure())
434446

435447
def _build_world_infos(self) -> dict[str, WorldInfo]:
436448
world_infos: dict[str, WorldInfo] = {}

0 commit comments

Comments
 (0)