2020from typing import Awaitable , Callable
2121
2222from infscale .configs .job import ServeConfig
23+ from infscale .execution .control import Channel as CtrlCh
24+ from infscale .execution .world import WorldInfo
2325
2426
2527class ConfigRunner :
@@ -34,11 +36,21 @@ def __init__(self):
3436 self ._event .set () # initially no configure running
3537 self ._curr_worlds_to_configure : set [str ] = set ()
3638 self ._cancel_cur_cfg = False
39+ self ._worlds_to_add : list [WorldInfo ]
40+ self ._worlds_to_remove : list [WorldInfo ]
41+ self ._world_infos : dict [str , WorldInfo ] = {}
3742
3843 def handle_new_spec (self , spec : ServeConfig ) -> None :
3944 """Handle new spec."""
4045 self ._cancel_cur_cfg = self ._should_cancel_current (spec )
4146 self ._spec = spec
47+ new_world_infos = self ._build_world_infos ()
48+
49+ new = new_world_infos .keys ()
50+ cur = self ._world_infos .keys ()
51+
52+ self ._worlds_to_add = self .get_words_to_add (new_world_infos , new , cur )
53+ self ._worlds_to_remove = self .get_words_to_remove (self ._world_infos , new , cur )
4254
4355 def _should_cancel_current (self , spec : ServeConfig ) -> bool :
4456 """Decide if current configuration should be cancelled."""
@@ -57,6 +69,35 @@ def set_worlds_to_configure(self, world_names: set[str]) -> None:
5769 """Set the world names currently being configured."""
5870 self ._curr_worlds_to_configure = world_names
5971
72+ def set_world_infos (self ) -> None :
73+ """Set new world infos."""
74+ for world_info in self ._worlds_to_add :
75+ self ._world_infos [world_info .name ] = world_info
76+
77+ def get_world_infos (self ) -> dict [str , WorldInfo ]:
78+ "Get world infos."
79+ return self ._world_infos
80+
81+ def is_first_run (self ) -> bool :
82+ "Return boolean if is first run or not."
83+ return not self ._world_infos
84+
85+ def remove_world_info (self , world_name : str ) -> None :
86+ """Remove world info by name."""
87+ del self ._world_infos [world_name ]
88+
89+ def get_words_to_add (
90+ self , world_infos : list [WorldInfo ], new : set [str ], cur : set [str ]
91+ ) -> list [WorldInfo ]:
92+ """Return a list of world infos to add."""
93+ return [world_infos [name ] for name in new - cur ]
94+
95+ def get_words_to_remove (
96+ self , world_infos : list [WorldInfo ], new : set [str ], cur : set [str ]
97+ ) -> list [WorldInfo ]:
98+ """Return a list of world infos to remove."""
99+ return [world_infos [name ] for name in cur - new ]
100+
60101 async def schedule (self , coro_factory : Callable [[], Awaitable [None ]]):
61102 """Cancel any in-progress configure and schedule a new one."""
62103 # wait for current to finish if we do not want to cancel
@@ -75,6 +116,60 @@ async def schedule(self, coro_factory: Callable[[], Awaitable[None]]):
75116 self ._event .clear ()
76117 self ._task = self ._loop .create_task (self ._run (coro_factory ))
77118
119+ def _build_world_infos (self ) -> dict [str , WorldInfo ]:
120+ world_infos : dict [str , WorldInfo ] = {}
121+
122+ my_id = self ._spec .stage .id
123+ for k , v in self ._spec .flow_graph .items ():
124+ for cfg_world_info in v :
125+ # NOTE: no. of peers is always 1 for now
126+ assert len (cfg_world_info .peers ) == 1
127+
128+ if my_id == k :
129+ my_rank = 0
130+ other_rank = 1
131+ other_id = cfg_world_info .peers [0 ]
132+ elif my_id in cfg_world_info .peers :
133+ # NOTE: this is always 1 for now
134+ my_rank = cfg_world_info .peers .index (my_id ) + 1
135+ other_rank = 0
136+ other_id = k
137+ else :
138+ continue
139+
140+ name , backend , addr , data_port , ctrl_port , recover , conflict_count = (
141+ cfg_world_info .name ,
142+ cfg_world_info .backend ,
143+ cfg_world_info .addr ,
144+ cfg_world_info .data_port ,
145+ cfg_world_info .ctrl_port ,
146+ cfg_world_info .recover ,
147+ cfg_world_info .conflict_count ,
148+ )
149+
150+ world_size = len (cfg_world_info .peers ) + 1
151+ ctrl_ch = CtrlCh (my_rank , world_size , addr , ctrl_port )
152+
153+ data = {
154+ "name" : name ,
155+ "size" : world_size ,
156+ "addr" : addr ,
157+ "port" : data_port ,
158+ "backend" : backend ,
159+ "channel" : ctrl_ch ,
160+ "my_id" : my_id ,
161+ "me" : my_rank ,
162+ "other_id" : other_id ,
163+ "other" : other_rank ,
164+ "recover" : recover ,
165+ "conflict_count" : conflict_count ,
166+ "multiworld_name" : f"{ name } -{ conflict_count } " ,
167+ }
168+ world_info = WorldInfo (** data )
169+ world_infos [name ] = world_info
170+
171+ return world_infos
172+
78173 async def _run (self , coro_factory : Callable [[], Awaitable [None ]]):
79174 """Run coroutine factory."""
80175 try :
0 commit comments