From 3adb4bd05e4fed8f4f2ff910632842cbd9cf967a Mon Sep 17 00:00:00 2001 From: aasiffaizal <38973423+aasiffaizal@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:52:55 -0500 Subject: [PATCH 1/2] Fix: Class disambiguation logic --- bump_pydantic/codemods/class_def_visitor.py | 36 +++++---- tests/integration/cases/__init__.py | 2 + tests/integration/cases/nested_inheritance.py | 77 +++++++++++++++++++ tests/unit/test_add_annotations.py | 6 +- tests/unit/test_add_default_none.py | 6 +- 5 files changed, 107 insertions(+), 20 deletions(-) create mode 100644 tests/integration/cases/nested_inheritance.py diff --git a/bump_pydantic/codemods/class_def_visitor.py b/bump_pydantic/codemods/class_def_visitor.py index b97aa0e..991c861 100644 --- a/bump_pydantic/codemods/class_def_visitor.py +++ b/bump_pydantic/codemods/class_def_visitor.py @@ -38,6 +38,14 @@ def __init__(self, context: CodemodContext) -> None: self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set()) self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set)) + def _recursively_disambiguate( + self, classname: str, context_set: set[str], ambiguous_classes: dict[str, set[str]] + ) -> None: + if classname in context_set and classname in ambiguous_classes: + for child_classname in ambiguous_classes.pop(classname): + context_set.add(child_classname) + self._recursively_disambiguate(child_classname, context_set, ambiguous_classes) + def visit_ClassDef(self, node: cst.ClassDef) -> None: fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) @@ -60,30 +68,30 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) # In case we have the following scenario: + # class ChildA(A): # class A(B): ... # class B(BaseModel): ... # class D(C): ... # class C: ... - # We want to disambiguate `A` as soon as we see `B` is a `BaseModel`. - if ( - fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY] - and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] - ): - for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): - self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class) + # We want to disambiguate `A` and then `ChildA` as soon as we see `B` is a `BaseModel`. + # We recursively add child classes to self.BASE_MODEL_CONTEXT_KEY. + self._recursively_disambiguate( + fqn.name, self.context.scratch[self.BASE_MODEL_CONTEXT_KEY], self.context.scratch[self.CLS_CONTEXT_KEY] + ) # In case we have the following scenario: # class A(B): ... # class B(BaseModel): ... + # class E(D): ... # class D(C): ... # class C: ... - # We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`. - if ( - fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY] - and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] - ): - for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): - self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class) + # We want to disambiguate `D` and then `E` as soon as we see `C` is NOT a `BaseModel`. + # We recursively add child classes to self.NO_BASE_MODEL_CONTEXT_KEY. + self._recursively_disambiguate( + fqn.name, + self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY], + self.context.scratch[self.CLS_CONTEXT_KEY], + ) # In case we have the following scenario: # class A(B): ... diff --git a/tests/integration/cases/__init__.py b/tests/integration/cases/__init__.py index ecb2ffc..c2ae14b 100644 --- a/tests/integration/cases/__init__.py +++ b/tests/integration/cases/__init__.py @@ -9,6 +9,7 @@ from .field import cases as generic_model_cases from .folder_inside_folder import cases as folder_inside_folder_cases from .is_base_model import cases as is_base_model_cases +from .nested_inheritance import cases as nested_inheritance_cases from .replace_validator import cases as replace_validator_cases from .root_model import cases as root_model_cases from .unicode import cases as unicode_cases @@ -22,6 +23,7 @@ *base_settings_cases, *add_none_cases, *is_base_model_cases, + *nested_inheritance_cases, *replace_validator_cases, *config_to_model_cases, *root_model_cases, diff --git a/tests/integration/cases/nested_inheritance.py b/tests/integration/cases/nested_inheritance.py new file mode 100644 index 0000000..a1cdec4 --- /dev/null +++ b/tests/integration/cases/nested_inheritance.py @@ -0,0 +1,77 @@ +from ..case import Case +from ..file import File +from ..folder import Folder + +cases = [ + Case( + name="Nested Inheritance", + source=Folder( + "nested_inheritance", + File("__init__.py", content=[]), + File( + "bar.py", + content=[ + "from .foo import Foo", + "", + "", + "class Bar(Foo):", + " b: str | None", + ], + ), + File( + "baz.py", + content=[ + "from .bar import Bar", + "", + "", + "class Baz(Bar):", + " c: str | None", + ], + ), + File( + "foo.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class Foo(BaseModel):", + " a: str | None", + ], + ), + ), + expected=Folder( + "nested_inheritance", + File("__init__.py", content=[]), + File( + "bar.py", + content=[ + "from .foo import Foo", + "", + "", + "class Bar(Foo):", + " b: str | None = None", + ], + ), + File( + "baz.py", + content=[ + "from .bar import Bar", + "", + "", + "class Baz(Bar):", + " c: str | None = None", + ], + ), + File( + "foo.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class Foo(BaseModel):", + " a: str | None = None", + ], + ), + ), + ) +] diff --git a/tests/unit/test_add_annotations.py b/tests/unit/test_add_annotations.py index 2750a96..fd05dd3 100644 --- a/tests/unit/test_add_annotations.py +++ b/tests/unit/test_add_annotations.py @@ -16,9 +16,9 @@ def add_annotations(self, file_path: str, code: str) -> cst.Module: mod = MetadataWrapper( parse_module(CodemodTest.make_fixture_data(code)), cache={ - FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( - file_path, "" - ) + FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache( + Path(""), [file_path], timeout=None + ).get(file_path, "") }, ) mod.resolve_many(AddAnnotationsCommand.METADATA_DEPENDENCIES) diff --git a/tests/unit/test_add_default_none.py b/tests/unit/test_add_default_none.py index a0f1202..3a4f28e 100644 --- a/tests/unit/test_add_default_none.py +++ b/tests/unit/test_add_default_none.py @@ -17,9 +17,9 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module: mod = MetadataWrapper( parse_module(CodemodTest.make_fixture_data(code)), cache={ - FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( - file_path, "" - ) + FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache( + Path(""), [file_path], timeout=None + ).get(file_path, "") }, ) mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES) From 01de55d502fd0d89f464678324a7e15ee1055a21 Mon Sep 17 00:00:00 2001 From: aasiffaizal <38973423+aasiffaizal@users.noreply.github.com> Date: Mon, 2 Dec 2024 10:45:58 -0500 Subject: [PATCH 2/2] refactor function --- bump_pydantic/codemods/class_def_visitor.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/bump_pydantic/codemods/class_def_visitor.py b/bump_pydantic/codemods/class_def_visitor.py index 991c861..1ad73e2 100644 --- a/bump_pydantic/codemods/class_def_visitor.py +++ b/bump_pydantic/codemods/class_def_visitor.py @@ -38,13 +38,11 @@ def __init__(self, context: CodemodContext) -> None: self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set()) self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set)) - def _recursively_disambiguate( - self, classname: str, context_set: set[str], ambiguous_classes: dict[str, set[str]] - ) -> None: - if classname in context_set and classname in ambiguous_classes: - for child_classname in ambiguous_classes.pop(classname): + def _recursively_disambiguate(self, classname: str, context_set: set[str]) -> None: + if classname in context_set and classname in self.context.scratch[self.CLS_CONTEXT_KEY]: + for child_classname in self.context.scratch[self.CLS_CONTEXT_KEY].pop(classname): context_set.add(child_classname) - self._recursively_disambiguate(child_classname, context_set, ambiguous_classes) + self._recursively_disambiguate(child_classname, context_set) def visit_ClassDef(self, node: cst.ClassDef) -> None: fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) @@ -75,9 +73,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: # class C: ... # We want to disambiguate `A` and then `ChildA` as soon as we see `B` is a `BaseModel`. # We recursively add child classes to self.BASE_MODEL_CONTEXT_KEY. - self._recursively_disambiguate( - fqn.name, self.context.scratch[self.BASE_MODEL_CONTEXT_KEY], self.context.scratch[self.CLS_CONTEXT_KEY] - ) + self._recursively_disambiguate(fqn.name, self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]) # In case we have the following scenario: # class A(B): ... @@ -87,11 +83,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: # class C: ... # We want to disambiguate `D` and then `E` as soon as we see `C` is NOT a `BaseModel`. # We recursively add child classes to self.NO_BASE_MODEL_CONTEXT_KEY. - self._recursively_disambiguate( - fqn.name, - self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY], - self.context.scratch[self.CLS_CONTEXT_KEY], - ) + self._recursively_disambiguate(fqn.name, self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]) # In case we have the following scenario: # class A(B): ...