You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
## Fix asymmetric patching/unpatching in
InsertPostInitMethodToModuleSubClasses
### Problem Description
The `InsertPostInitMethodToModuleSubClasses` context manager patches
`__init__` methods of model classes during entry and unpatches them
during exit.
However, asymmetric condition checks between patching and unpatching can
introduce subtle inheritance bugs.
### Root Cause Analysis
The issue occurs with classes that have multiple inheritance where:
1. **Child class A** does not override `__init__`
2. **Parent class B** does not inherit from `nn.Module`
3. **Parent class C** inherits from `nn.Module`
**Current asymmetric logic:**
```python
# Patching (entry): Only patch classes with explicit __init__
def _enable_class(cls):
if '__init__' in cls.__dict__: # ✅ Strict check
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
# Unpatching (exit): Restore any class with _old_init
def _disable_class(cls):
if hasattr(cls, '_old_init'): # ❌ Permissive check
cls.__init__ = cls._old_init
```
**Execution flow:**
1. **During entry**: Child A is skipped (no explicit `__init__`), Parent
C is patched
2. **During exit**: Child A inherits `_old_init` from Parent C and gets
incorrectly "restored"
**Result**: Child A's `__init__` points to Parent C's original
`__init__`, bypassing Parent B and breaking the inheritance chain.
### Reproduction Case
This pattern is common in Hugging Face models:
```python
class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
pass # No explicit __init__
# GenericForSequenceClassification - not a nn.Module subclass
# Qwen3PreTrainedModel - inherits from nn.Module
```
### Solution
Apply symmetric condition checking in both patch and unpatch operations:
```python
def _disable_class(cls):
# Match the patching condition: only restore classes we explicitly patched
if '__init__' in cls.__dict__ and hasattr(cls, '_old_init'):
cls.__init__ = cls._old_init
delattr(cls, '_old_init') # Optional cleanup
```
This ensures that only classes that were explicitly patched during entry
get restored during exit.
### Testing
The fix has been validated against the Qwen3ForSequenceClassification
reproduction case and resolves the inheritance chain corruption.
### Related Issues
- External issue: modelscope/ms-swift#5820
Co-authored-by: Masahiro Tanaka <[email protected]>
0 commit comments