Skip to content

Commit e4f6da9

Browse files
hjh0119tohtana
andauthored
[bugfix] fix partition context unpatch (#7566)
## Fix asymmetric patching/unpatching in InsertPostInitMethodToModuleSubClasses ### Problem Description The `InsertPostInitMethodToModuleSubClasses` context manager patches `__init__` methods of model classes during entry and unpatches them during exit. However, asymmetric condition checks between patching and unpatching can introduce subtle inheritance bugs. ### Root Cause Analysis The issue occurs with classes that have multiple inheritance where: 1. **Child class A** does not override `__init__` 2. **Parent class B** does not inherit from `nn.Module` 3. **Parent class C** inherits from `nn.Module` **Current asymmetric logic:** ```python # Patching (entry): Only patch classes with explicit __init__ def _enable_class(cls): if '__init__' in cls.__dict__: # ✅ Strict check cls._old_init = cls.__init__ cls.__init__ = partition_after(cls.__init__) # Unpatching (exit): Restore any class with _old_init def _disable_class(cls): if hasattr(cls, '_old_init'): # ❌ Permissive check cls.__init__ = cls._old_init ``` **Execution flow:** 1. **During entry**: Child A is skipped (no explicit `__init__`), Parent C is patched 2. **During exit**: Child A inherits `_old_init` from Parent C and gets incorrectly "restored" **Result**: Child A's `__init__` points to Parent C's original `__init__`, bypassing Parent B and breaking the inheritance chain. ### Reproduction Case This pattern is common in Hugging Face models: ```python class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel): pass # No explicit __init__ # GenericForSequenceClassification - not a nn.Module subclass # Qwen3PreTrainedModel - inherits from nn.Module ``` ### Solution Apply symmetric condition checking in both patch and unpatch operations: ```python def _disable_class(cls): # Match the patching condition: only restore classes we explicitly patched if '__init__' in cls.__dict__ and hasattr(cls, '_old_init'): cls.__init__ = cls._old_init delattr(cls, '_old_init') # Optional cleanup ``` This ensures that only classes that were explicitly patched during entry get restored during exit. ### Testing The fix has been validated against the Qwen3ForSequenceClassification reproduction case and resolves the inheritance chain corruption. ### Related Issues - External issue: modelscope/ms-swift#5820 Co-authored-by: Masahiro Tanaka <[email protected]>
1 parent 6b731c5 commit e4f6da9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

deepspeed/runtime/zero/partition_parameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def unpatch_init_and_builtins(self):
594594
if self.patched:
595595

596596
def _disable_class(cls):
597-
if hasattr(cls, '_old_init'):
597+
if '__init__' in cls.__dict__ and hasattr(cls, '_old_init'):
598598
cls.__init__ = cls._old_init
599599

600600
for subclass in get_all_subclasses(torch.nn.modules.module.Module):

0 commit comments

Comments
 (0)