From 74d67187be9e167856558bee9d4e9708ecdd4f3a Mon Sep 17 00:00:00 2001 From: Stephen Macke Date: Fri, 27 Oct 2023 07:56:02 -0700 Subject: [PATCH] self reloadable --- core/superduperreload/patching.py | 53 ++++++++++------------- core/superduperreload/superduperreload.py | 10 +++-- core/superduperreload/utils.py | 19 ++++++++ 3 files changed, 49 insertions(+), 33 deletions(-) create mode 100644 core/superduperreload/utils.py diff --git a/core/superduperreload/patching.py b/core/superduperreload/patching.py index 2955f32..d8ae0b2 100644 --- a/core/superduperreload/patching.py +++ b/core/superduperreload/patching.py @@ -7,6 +7,8 @@ from types import FunctionType, MethodType from typing import Callable, Dict, List, Optional, Set, Sized, Tuple, Type, Union +from superduperreload.utils import isinstance2 + if sys.maxsize > 2**32: WORD_TYPE: Union[Type[ctypes.c_int32], Type[ctypes.c_int64]] = ctypes.c_int64 WORD_N_BYTES = 8 @@ -15,10 +17,6 @@ WORD_N_BYTES = 4 -def isinstance2(a, b, typ): - return isinstance(a, typ) and isinstance(b, typ) - - # Placeholder for indicating an attribute is not found _NOT_FOUND: object = object() @@ -53,11 +51,6 @@ class _CPythonStructType(Enum): PARTIALMETHOD = "partialmethod" -_FIELD_OFFSET_LOOKUP_TABLE_BY_STRUCT_TYPE: Dict[_CPythonStructType, Dict[str, int]] = { - field_type: {} for field_type in _CPythonStructType -} - - _MAX_FIELD_SEARCH_OFFSET = 50 _MAX_REFERRERS_FOR_PATCHING = 512 _MAX_REFERRER_LENGTH_FOR_PATCHING = 128 @@ -82,6 +75,8 @@ class _CPythonStructType(Enum): class ObjectPatcher: + _FIELD_OFFSET_LOOKUP_TABLE_BY_STRUCT_TYPE: Dict[str, Dict[str, int]] = {} + def __init__(self, patch_referrers: bool) -> None: self._patched_obj_ids: Set[int] = set() self._patch_rules = [ @@ -114,7 +109,9 @@ def _infer_field_offset( if field_value is _NOT_FOUND: return -1 if cache: - offset_tab = _FIELD_OFFSET_LOOKUP_TABLE_BY_STRUCT_TYPE[struct_type] + offset_tab = cls._FIELD_OFFSET_LOOKUP_TABLE_BY_STRUCT_TYPE.setdefault( + struct_type.value, {} + ) else: offset_tab = {} ret = offset_tab.get(field) @@ -188,7 +185,6 @@ def _try_patch_readonly_attr( cls._try_write_readonly_attr(struct_type, old, field, new_value, offset=offset) def _patch_function(self, old, new): - """Upgrade the code object of a function""" if old is new: return for name in _FUNC_ATTRS: @@ -223,13 +219,8 @@ def _patch_class_members(self, old: Type[object], new: Type[object]) -> None: for key in list(old.__dict__.keys()): old_obj = getattr(old, key) new_obj = getattr(new, key, _NOT_FOUND) - try: - if (old_obj == new_obj) is True: - continue - except ValueError: - # can't compare nested structures containing - # numpy arrays using `==` - pass + if old_obj is new_obj: + continue if new_obj is _NOT_FOUND and isinstance(old_obj, ClassCallableTypes): # obsolete attribute: remove it try: @@ -295,6 +286,18 @@ def _patch_partialmethod( _CPythonStructType.PARTIALMETHOD, old, new, "keywords" ) + def _patch_generic(self, old: object, new: object) -> None: + if old is new: + return + old_id = id(old) + if old_id in self._patched_obj_ids: + return + self._patched_obj_ids.add(old_id) + for type_check, patch in self._patch_rules: + if type_check(old, new): + patch(old, new) + break + def _patch_list_referrer(self, ref: List[object], old: object, new: object) -> None: for i, obj in enumerate(list(ref)): if obj is old: @@ -304,6 +307,8 @@ def _patch_dict_referrer( self, ref: Dict[object, object], old: object, new: object ) -> None: # reinsert everything in the dict in iteration order, updating refs of 'old' to 'new' + # if hasattr(old, "__class__") and issubclass(old.__class__, Enum): + # print(old, new, ref) for k, v in dict(ref).items(): if k is old: del ref[k] @@ -313,18 +318,6 @@ def _patch_dict_referrer( else: ref[k] = v - def _patch_generic(self, old: object, new: object) -> None: - if old is new: - return - old_id = id(old) - if old_id in self._patched_obj_ids: - return - self._patched_obj_ids.add(old_id) - for type_check, patch in self._patch_rules: - if type_check(old, new): - patch(old, new) - break - def _patch_referrers_generic(self, old: object, new: object) -> None: if not self._patch_referrers: return diff --git a/core/superduperreload/superduperreload.py b/core/superduperreload/superduperreload.py index 25cbe46..e2c0ab4 100644 --- a/core/superduperreload/superduperreload.py +++ b/core/superduperreload/superduperreload.py @@ -44,6 +44,7 @@ from superduperreload.functional_reload import exec_module_for_new_dict from superduperreload.patching import IMMUTABLE_PRIMITIVE_TYPES, ObjectPatcher +from superduperreload.utils import print_purple if TYPE_CHECKING: from IPython import InteractiveShell @@ -76,6 +77,8 @@ def __init__( super().__init__(patch_referrers=SHOULD_PATCH_REFERRERS) # Whether this reloader is enabled self.enabled = True + # Whether to print reloaded modules and other messages + self.verbose = True # Modules that failed to reload: {module: mtime-on-failed-reload, ...} self.failed: Dict[str, float] = {} # Modules specially marked as not autoreloadable. @@ -99,12 +102,13 @@ def __init__( self.reloaded_modules: List[str] = [] self.failed_modules: List[str] = [] - # Reporting callable for verbosity - self._report = lambda msg: None # by default, be quiet. - # Cache module modification times self.check(do_reload=False) + def _report(self, msg: str) -> None: + if self.verbose: + print_purple(msg) + def mark_module_skipped(self, module_name: str) -> None: """Skip reloading the named module in the future""" self.skip_modules.add(module_name) diff --git a/core/superduperreload/utils.py b/core/superduperreload/utils.py new file mode 100644 index 0000000..1eb48c9 --- /dev/null +++ b/core/superduperreload/utils.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- + +_PURPLE = "\033[95m" +_RESET = "\033[0m" + + +def print_purple(text: str) -> None: + # The ANSI escape code for purple text is \033[95m + # The \033 is the escape code, and [95m specifies the color (purple) + # Reset code is \033[0m that resets the style to default + print(f"{_PURPLE}{text}{_RESET}") + + +def isinstance2(a, b, typ): + return isinstance(a, typ) and isinstance(b, typ) + + +def issubclass2(a, b, typ): + return issubclass(a, typ) and issubclass(b, typ)