Skip to content

Commit d90516b

Browse files
committed
refactor: pipeline world infos
Refactored the pipeline code and moved the world-info-related logic into config_runner. With this change, config_runner is responsible for computing the worlds to add and to remove; this will also be used when we refactor the code to handle per-world configuration.
1 parent 0112e8a commit d90516b

File tree

2 files changed

+118
-28
lines changed

2 files changed

+118
-28
lines changed

infscale/execution/config_runner.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from typing import Awaitable, Callable
2121

2222
from infscale.configs.job import ServeConfig
23+
from infscale.execution.control import Channel as CtrlCh
24+
from infscale.execution.world import WorldInfo
2325

2426

2527
class ConfigRunner:
@@ -34,11 +36,21 @@ def __init__(self):
3436
self._event.set() # initially no configure running
3537
self._curr_worlds_to_configure: set[str] = set()
3638
self._cancel_cur_cfg = False
39+
self._worlds_to_add: list[WorldInfo]
40+
self._worlds_to_remove: list[WorldInfo]
41+
self._world_infos: dict[str, WorldInfo] = {}
3742

3843
def handle_new_spec(self, spec: ServeConfig) -> None:
    """Store a new spec and compute the worlds it adds and removes.

    The cancel flag is evaluated first, then ``spec`` becomes the current
    spec. World infos are rebuilt from it and compared by name against the
    world infos currently tracked; the results are stored in
    ``self._worlds_to_add`` and ``self._worlds_to_remove``.
    """
    # evaluated before self._spec is overwritten with the new spec
    self._cancel_cur_cfg = self._should_cancel_current(spec)
    self._spec = spec

    # _build_world_infos() reads self._spec, which was assigned just above
    candidate_infos = self._build_world_infos()
    new_names = candidate_infos.keys()
    known_names = self._world_infos.keys()

    # worlds to add are looked up in the freshly built infos
    self._worlds_to_add = self.get_words_to_add(
        candidate_infos, new_names, known_names
    )
    # worlds to remove are looked up in the currently tracked infos
    self._worlds_to_remove = self.get_words_to_remove(
        self._world_infos, new_names, known_names
    )
4254

4355
def _should_cancel_current(self, spec: ServeConfig) -> bool:
4456
"""Decide if current configuration should be cancelled."""
@@ -57,6 +69,35 @@ def set_worlds_to_configure(self, world_names: set[str]) -> None:
5769
"""Set the world names currently being configured."""
5870
self._curr_worlds_to_configure = world_names
5971

72+
def set_world_infos(self) -> None:
    """Record every world in ``self._worlds_to_add`` as a tracked world info.

    Each entry is stored in ``self._world_infos`` under its ``name``; an
    existing entry with the same name is overwritten.
    """
    # feeding (name, info) pairs lets dict.update insert them one by one
    self._world_infos.update(
        (pending.name, pending) for pending in self._worlds_to_add
    )
76+
77+
def get_world_infos(self) -> dict[str, WorldInfo]:
    """Return the world infos tracked by this runner, keyed by world name.

    The internal dictionary itself is returned, not a copy, so changes made
    through the returned object are visible to the runner.
    """
    tracked = self._world_infos
    return tracked
80+
81+
def is_first_run(self) -> bool:
    """Return True when no world info is currently tracked, else False."""
    has_worlds = bool(self._world_infos)
    return not has_worlds
84+
85+
def remove_world_info(self, world_name: str) -> None:
    """Drop the tracked world info stored under ``world_name``.

    Raises:
        KeyError: if no world info is tracked under ``world_name``.
    """
    # pop without a default raises KeyError for an unknown name, like ``del``
    self._world_infos.pop(world_name)
88+
89+
def get_words_to_add(
    self, world_infos: dict[str, WorldInfo], new: set[str], cur: set[str]
) -> list[WorldInfo]:
    """Return the world infos whose names are in ``new`` but not in ``cur``.

    The method name keeps its historical "words" spelling because callers
    reference it under that name.

    Args:
        world_infos: mapping of world name to world info; it must contain
            every name in ``new`` that is absent from ``cur``.
        new: names of the worlds in the new configuration. Any set-like
            collection works, including a dict keys view.
        cur: names of the worlds currently tracked (same kinds accepted).

    Returns:
        The matching world infos, in the iteration order of ``new``. When
        ``new`` is a dict keys view the result is in insertion order rather
        than the arbitrary order of a set difference.
    """
    return [world_infos[name] for name in new if name not in cur]
94+
95+
def get_words_to_remove(
    self, world_infos: dict[str, WorldInfo], new: set[str], cur: set[str]
) -> list[WorldInfo]:
    """Return the world infos whose names are in ``cur`` but not in ``new``.

    The method name keeps its historical "words" spelling because callers
    reference it under that name.

    Args:
        world_infos: mapping of world name to world info; it must contain
            every name in ``cur`` that is absent from ``new``.
        new: names of the worlds in the new configuration. Any set-like
            collection works, including a dict keys view.
        cur: names of the worlds currently tracked (same kinds accepted).

    Returns:
        The matching world infos, in the iteration order of ``cur``. When
        ``cur`` is a dict keys view the result is in insertion order rather
        than the arbitrary order of a set difference.
    """
    return [world_infos[name] for name in cur if name not in new]
100+
60101
async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
61102
"""Cancel any in-progress configure and schedule a new one."""
62103
# wait for current to finish if we do not want to cancel
@@ -75,6 +116,60 @@ async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
75116
self._event.clear()
76117
self._task = self._loop.create_task(self._run(coro_factory))
77118

119+
def _build_world_infos(self) -> dict[str, WorldInfo]:
    """Build a world info for each world in the current spec that involves this stage.

    Walks ``self._spec.flow_graph`` and, for every configured world whose
    graph key or peer list contains this stage's id, creates a control
    channel and a ``WorldInfo``. Worlds that do not involve this stage are
    skipped.

    Returns:
        A dictionary mapping world name to the ``WorldInfo`` built for it.
    """
    my_id = self._spec.stage.id
    built: dict[str, WorldInfo] = {}

    for owner_id, cfg_worlds in self._spec.flow_graph.items():
        for cfg in cfg_worlds:
            # NOTE: no. of peers is always 1 for now
            assert len(cfg.peers) == 1

            if my_id == owner_id:
                # this stage is the flow-graph key of the world: rank 0
                my_rank, other_rank = 0, 1
                other_id = cfg.peers[0]
            elif my_id in cfg.peers:
                # NOTE: this is always 1 for now
                my_rank, other_rank = cfg.peers.index(my_id) + 1, 0
                other_id = owner_id
            else:
                # this world does not involve the current stage
                continue

            name = cfg.name
            addr = cfg.addr
            conflict_count = cfg.conflict_count

            # the flow-graph key plus its peers
            world_size = len(cfg.peers) + 1
            ctrl_ch = CtrlCh(my_rank, world_size, addr, cfg.ctrl_port)

            built[name] = WorldInfo(
                name=name,
                size=world_size,
                addr=addr,
                port=cfg.data_port,
                backend=cfg.backend,
                channel=ctrl_ch,
                my_id=my_id,
                me=my_rank,
                other_id=other_id,
                other=other_rank,
                recover=cfg.recover,
                conflict_count=conflict_count,
                multiworld_name=f"{name}-{conflict_count}",
            )

    return built
172+
78173
async def _run(self, coro_factory: Callable[[], Awaitable[None]]):
79174
"""Run coroutine factory."""
80175
try:

infscale/execution/pipeline.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def __init__(
6868
self.wcomm = wcomm
6969
self.spec: ServeConfig = None
7070
self.device = None
71-
self.world_infos: dict[str, WorldInfo] = {}
7271
self.cfg_event = asyncio.Event()
7372
self._micro_batch_size = 1
7473
self._initialized = False
@@ -118,7 +117,9 @@ def _set_worker_status(self, status: WorkerStatus) -> None:
118117
"""Set worker status in pipeline and channel."""
119118
self._status = status
120119

121-
for world_info in self.world_infos.values():
120+
world_infos = self.config_runner.get_world_infos()
121+
122+
for world_info in world_infos.values():
122123
world_info.channel.set_worker_status(status)
123124

124125
def _set_n_send_worker_status(self, status: WorkerStatus) -> None:
@@ -143,52 +144,42 @@ def _reset_control_channel(self, world_info: WorldInfo) -> None:
143144

144145
async def _cleanup_recovered_worlds(self) -> None:
145146
"""Clean up world infos for recovered worlds."""
147+
world_infos = self.config_runner.get_world_infos()
148+
146149
# if I'm the recovered worker, return
147-
if len(self.world_infos) == 0:
150+
if len(world_infos) == 0:
148151
return
149152

150153
recover_worlds = [
151154
world_info
152155
for world_list in self.spec.flow_graph.values()
153156
for world_info in world_list
154-
if world_info.recover and world_info.name in self.world_infos
157+
if world_info.recover and world_info.name in world_infos
155158
]
156159

157160
# no worlds to recover
158161
if len(recover_worlds) == 0:
159162
return
160163

161164
for world_info in recover_worlds:
162-
wi = self.world_infos.get(world_info.name, None)
165+
wi = world_infos.get(world_info.name, None)
163166

164167
await self.router.cleanup_world(wi)
165168
self._reset_control_channel(wi)
166169
self._reset_multiworld(wi)
167170

168-
del self.world_infos[wi.name]
171+
self.config_runner.remove_world_info(wi.name)
169172

170173
async def _configure(self) -> None:
171174
"""(Re)configure multiworld, control channel and router."""
172-
await self._cleanup_recovered_worlds()
173-
174-
is_first_run = not self.world_infos
175+
is_first_run = self.config_runner.is_first_run()
175176

176177
if not is_first_run:
177178
self._set_worker_status(WorkerStatus.UPDATING)
178179

179-
new_world_infos = self._build_world_infos()
180-
new = new_world_infos.keys()
181-
cur = self.world_infos.keys()
182-
183-
worlds_to_add = [new_world_infos[name] for name in new - cur]
184-
worlds_to_remove = [self.world_infos[name] for name in cur - new]
185-
186-
self.config_runner.set_worlds_to_configure(new - cur)
187-
188-
# handle new worlds
189180
tasks = []
190181
# 1. set up control channel
191-
for world_info in worlds_to_add:
182+
for world_info in self.config_runner._worlds_to_add:
192183
task = self._configure_control_channel(world_info)
193184
tasks.append(task)
194185

@@ -198,7 +189,7 @@ async def _configure(self) -> None:
198189

199190
tasks = []
200191
# 2. set up multiworld
201-
for world_info in worlds_to_add:
192+
for world_info in self.config_runner._worlds_to_add:
202193
task = self._configure_multiworld(world_info)
203194
tasks.append(task)
204195

@@ -207,23 +198,25 @@ async def _configure(self) -> None:
207198
await asyncio.gather(*tasks)
208199

209200
# update world_info for added worlds
210-
for world_info in worlds_to_add:
211-
self.world_infos[world_info.name] = world_info
201+
self.config_runner.set_world_infos()
212202

213203
# configure router with worlds to add and remove
214204
await self.router.configure(
215-
self.spec, self.device, worlds_to_add, worlds_to_remove
205+
self.spec,
206+
self.device,
207+
self.config_runner._worlds_to_add,
208+
self.config_runner._worlds_to_remove,
216209
)
217210

218211
# handle unnecessary world
219212
# remove is executed in the reverse order of add
220-
for world_info in worlds_to_remove:
213+
for world_info in self.config_runner._worlds_to_remove:
221214
# 1. remove unnecessary world from control channel
222215
self._reset_control_channel(world_info)
223216
# 2. remove unnecessary world from multiworld
224217
self._reset_multiworld(world_info)
225218

226-
del self.world_infos[world_info.name]
219+
self.config_runner.remove_world_info(world_info.name)
227220

228221
worker_status = WorkerStatus.RUNNING if is_first_run else WorkerStatus.UPDATED
229222

@@ -426,10 +419,12 @@ async def _handle_config(self, spec: ServeConfig) -> None:
426419
if spec is None:
427420
return
428421

429-
self.config_runner.handle_new_spec(spec)
430-
431422
self._configure_variables(spec)
432423

424+
await self._cleanup_recovered_worlds()
425+
426+
self.config_runner.handle_new_spec(spec)
427+
433428
self._inspector.configure(self.spec)
434429

435430
self._initialize_once()

0 commit comments

Comments
 (0)