Skip to content

Commit ab1961b

Browse files
committed
working on ipyflow integration
1 parent d9b5f5b commit ab1961b

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

core/superduperreload/patching.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class ObjectPatcher:
7979

8080
def __init__(self, patch_referrers: bool) -> None:
8181
self._patched_obj_ids: Set[int] = set()
82-
self._patch_rules = [
82+
self._patch_rules: List[Tuple[Callable, Callable]] = [
8383
(lambda a, b: isinstance2(a, b, type), self._patch_class),
8484
(lambda a, b: isinstance2(a, b, FunctionType), self._patch_function),
8585
(lambda a, b: isinstance2(a, b, MethodType), self._patch_method),
@@ -184,25 +184,31 @@ def _try_patch_readonly_attr(
184184
offset = cls._infer_field_offset(struct_type, new, field)
185185
cls._try_write_readonly_attr(struct_type, old, field, new_value, offset=offset)
186186

187-
def _patch_function(self, old, new):
187+
def _patch_function(self, old: FunctionType, new: FunctionType) -> None:
188188
if old is new:
189189
return
190190
for name in _FUNC_ATTRS:
191+
if name == "__globals__":
192+
# switch order for __globals__ since we keep the old module.__dict__
193+
old, new = new, old
191194
try:
192195
setattr(old, name, getattr(new, name))
193196
except (AttributeError, TypeError, ValueError):
194197
self._try_patch_readonly_attr(
195198
_CPythonStructType.FUNCTION, old, new, name
196199
)
200+
finally:
201+
if name == "__globals__":
202+
old, new = new, old
197203

198-
def _patch_method(self, old: MethodType, new: MethodType):
204+
def _patch_method(self, old: MethodType, new: MethodType) -> None:
199205
if old is new:
200206
return
201207
self._patch_function(old.__func__, new.__func__)
202208
self._try_patch_readonly_attr(_CPythonStructType.METHOD, old, new, "__self__")
203209

204210
@classmethod
205-
def _patch_instances(cls, old, new):
211+
def _patch_instances(cls, old: Type[object], new: Type[object]) -> None:
206212
"""Use garbage collector to find all instances that refer to the old
207213
class definition and update their __class__ to point to the new class
208214
definition"""
@@ -269,7 +275,7 @@ def _patch_property(self, old: property, new: property) -> None:
269275
def _patch_partial(self, old: functools.partial, new: functools.partial) -> None:
270276
if old is new:
271277
return
272-
self._patch_function(old.func, new.func)
278+
self._patch_function(old.func, new.func) # type: ignore
273279
self._try_patch_readonly_attr(_CPythonStructType.PARTIAL, old, new, "args")
274280
self._try_patch_readonly_attr(_CPythonStructType.PARTIAL, old, new, "keywords")
275281

core/superduperreload/superduperreload.py

+41-19
Original file line numberDiff line numberDiff line change
@@ -162,39 +162,41 @@ def filename_and_mtime(
162162

163163
return py_filename, pymtime
164164

165-
def check(self, do_reload: bool = True) -> None:
166-
"""Check whether some modules need to be reloaded."""
167-
self.reloaded_modules.clear()
168-
self.failed_modules.clear()
169-
170-
# TODO: we should try to reload the modules in topological order
165+
def _get_modules_needing_reload(self) -> Dict[str, Tuple[ModuleType, str, float]]:
166+
modules_needing_reload = {}
171167
for modname, m in list(sys.modules.items()):
172168
package_components = modname.split(".")
173169
if any(
174170
".".join(package_components[:idx]) in self.skip_modules
175171
for idx in range(1, len(package_components))
176172
):
177173
continue
178-
179174
py_filename, pymtime = self.filename_and_mtime(m)
180175
if py_filename is None:
181176
continue
182-
183-
try:
184-
if pymtime <= self.modules_mtimes[modname]:
185-
continue
186-
except KeyError:
187-
self.modules_mtimes[modname] = pymtime
177+
if pymtime <= self.modules_mtimes.setdefault(modname, pymtime):
178+
continue
179+
if self.failed.get(py_filename) == pymtime:
188180
continue
189-
else:
190-
if self.failed.get(py_filename, None) == pymtime:
191-
continue
192-
193181
self.modules_mtimes[modname] = pymtime
182+
modules_needing_reload[modname] = (m, py_filename, pymtime)
183+
return modules_needing_reload
194184

195-
if not do_reload:
196-
continue
185+
def check(self, do_reload: bool = True) -> None:
186+
"""Check whether some modules need to be reloaded."""
187+
self.reloaded_modules.clear()
188+
self.failed_modules.clear()
197189

190+
modules_needing_reload = self._get_modules_needing_reload()
191+
if not do_reload:
192+
return
193+
194+
# TODO: we should try to reload the modules in topological order
195+
for modname, (
196+
m,
197+
py_filename,
198+
pymtime,
199+
) in modules_needing_reload.items():
198200
# If we've reached this point, we should try to reload the module
199201
self._report(f"Reloading '{modname}'.")
200202
try:
@@ -220,6 +222,18 @@ def maybe_track_obj(self, module: ModuleType, name: str, obj: object) -> None:
220222
except TypeError:
221223
pass
222224

225+
def _patch_ipyflow_symbols(self, old, new, flow_):
226+
if flow_ is None:
227+
return
228+
if isinstance(old, IMMUTABLE_PRIMITIVE_TYPES):
229+
return
230+
old_id = id(old)
231+
if old_id not in flow_.aliases:
232+
return
233+
for sym in list(flow_.aliases[old_id]):
234+
sym._override_ready_liveness_cell_num = flow_.cell_counter()
235+
sym.update_obj_ref(new)
236+
223237
def superduperreload(self, module: ModuleType) -> ModuleType:
224238
"""Enhanced version of the superreload function from IPython's autoreload extension.
225239
@@ -229,6 +243,13 @@ def superduperreload(self, module: ModuleType) -> ModuleType:
229243
- upgrades the code object of every old function and method
230244
- clears the module's namespace before reloading
231245
"""
246+
try:
247+
from ipyflow import flow
248+
249+
flow_ = flow()
250+
except:
251+
flow_ = None
252+
232253
self._patched_obj_ids.clear()
233254

234255
# collect old objects in the module
@@ -254,6 +275,7 @@ def superduperreload(self, module: ModuleType) -> ModuleType:
254275
continue
255276
self._patch_generic(old_obj, new_obj)
256277
self._patch_referrers_generic(old_obj, new_obj)
278+
self._patch_ipyflow_symbols(old_obj, new_obj, flow_)
257279

258280
if new_refs:
259281
self.old_objects[key] = new_refs

0 commit comments

Comments
 (0)