From ee0305062b8b2a659dbbe1b02d988d01f70c8ccb Mon Sep 17 00:00:00 2001 From: Kristian Nymann Jakobsen Date: Mon, 19 Sep 2022 21:39:29 +0200 Subject: [PATCH] :construction: WIP --- .../refactorings/extract_variable/refactor.py | 6 +- .../refactorings/extract_variable/visitor.py | 5 +- .../refactorings/inline_variable/refactor.py | 6 +- .../refactorings/inline_variable/visitor.py | 40 ++++++++++++- src/refacto/refactorings/refactor.py | 15 +++-- src/refacto/refactorings/visitor.py | 4 +- tests/unit_tests/test_get_node_from_range.py | 60 +++++++++++++++++++ 7 files changed, 119 insertions(+), 17 deletions(-) create mode 100644 tests/unit_tests/test_get_node_from_range.py diff --git a/src/refacto/refactorings/extract_variable/refactor.py b/src/refacto/refactorings/extract_variable/refactor.py index e38e53f..8ae15ff 100644 --- a/src/refacto/refactorings/extract_variable/refactor.py +++ b/src/refacto/refactorings/extract_variable/refactor.py @@ -8,13 +8,13 @@ class RefactorExtractVariable(Refactor): def __init__(self) -> None: - self.visitor: ExpressionFinder = ExpressionFinder() + super().__init__(visitor=ExpressionFinder) def create_transformer(self, selected_range: Range) -> RefactoringTransformer: - if self.visitor.expr is None: + if self.visitor.expr is None: # type:ignore raise RuntimeError("Node was unexpectically None!") return ExtractVariableTransformer( - expr=self.visitor.expr, + expr=self.visitor.expr, # type: ignore variable_name="rename_me", selected_range=selected_range, ) diff --git a/src/refacto/refactorings/extract_variable/visitor.py b/src/refacto/refactorings/extract_variable/visitor.py index aa64b15..2aa6087 100644 --- a/src/refacto/refactorings/extract_variable/visitor.py +++ b/src/refacto/refactorings/extract_variable/visitor.py @@ -1,12 +1,13 @@ import libcst as cst +from pygls.lsp.types.basic_structures import Range from refacto.refactorings.visitor import RefactoringVisitor class ExpressionFinder(RefactoringVisitor): - def __init__(self) -> None: - super().__init__() + def __init__(self, selected_range: Range) -> None: self.expr: cst.Expr | None = None + super().__init__(selected_range=selected_range) def visit_Expr(self, node: cst.Expr) -> bool: self.expr = node diff --git a/src/refacto/refactorings/inline_variable/refactor.py b/src/refacto/refactorings/inline_variable/refactor.py index 6ac90cc..61459ef 100644 --- a/src/refacto/refactorings/inline_variable/refactor.py +++ b/src/refacto/refactorings/inline_variable/refactor.py @@ -8,12 +8,12 @@ class RefactorInlineVariable(Refactor): def __init__(self) -> None: - self.visitor: NameFinder = NameFinder() + super().__init__(visitor=NameFinder) def create_transformer(self, selected_range: Range) -> RefactoringTransformer: - if self.visitor.name is None: + if self.visitor.name is None: # type: ignore raise RuntimeError("Couldn't find variable to inline :-(") return InlineVariableTransformer( - name=self.visitor.name, + name=self.visitor.name, # type: ignore selected_range=selected_range, ) diff --git a/src/refacto/refactorings/inline_variable/visitor.py b/src/refacto/refactorings/inline_variable/visitor.py index 5d52dd0..547c254 100644 --- a/src/refacto/refactorings/inline_variable/visitor.py +++ b/src/refacto/refactorings/inline_variable/visitor.py @@ -1,12 +1,48 @@ +from devtools import debug import libcst as cst +from libcst.metadata import ParentNodeProvider +from libcst.metadata import WhitespaceInclusivePositionProvider +from libcst.metadata.scope_provider import Scope +from libcst.metadata.scope_provider import ScopeProvider +from pygls.lsp.types.basic_structures import Range from refacto.refactorings.visitor import RefactoringVisitor class NameFinder(RefactoringVisitor): - def __init__(self) -> None: + METADATA_DEPENDENCIES = ( + WhitespaceInclusivePositionProvider, + ScopeProvider, + ParentNodeProvider, + ) + + def __init__(self, selected_range: Range) -> None: self.name: cst.Name | None = None + self.parent: cst.CSTNode | None = None + self.scope: Scope | None = None + super().__init__(selected_range=selected_range) def visit_Name(self, node: cst.Name) -> bool: + if self._is_same_starting_position(node=node): + self._set_things(node=node) + return False + return True + + def _set_things(self, node: cst.Name) -> None: self.name = node - return False + self.scope = self.get_metadata(ScopeProvider, node) + try: + self.parent = self.get_metadata(ParentNodeProvider, node) + except KeyError: + return + + def _is_same_starting_position(self, node: cst.Name) -> bool: + libcst_range = self.get_metadata(WhitespaceInclusivePositionProvider, node) + if node.value == "shipping_price": + debug(str(libcst_range), node) + return all( + [ + self.selected_range.start.line == libcst_range.start.line - 1, + self.selected_range.start.character == libcst_range.start.column, + ], + ) diff --git a/src/refacto/refactorings/refactor.py b/src/refacto/refactorings/refactor.py index fa7f888..026dfa1 100644 --- a/src/refacto/refactorings/refactor.py +++ b/src/refacto/refactorings/refactor.py @@ -1,5 +1,6 @@ from abc import ABC from abc import abstractmethod +from typing import Type import libcst as cst from pygls.lsp.types.basic_structures import Position @@ -11,17 +12,19 @@ class Refactor(ABC): - def __init__(self, visitor: RefactoringVisitor) -> None: - self.visitor = visitor + def __init__(self, visitor: Type[RefactoringVisitor]) -> None: + self.visitor_type = visitor + self.visitor: RefactoringVisitor | None = None def refactor(self, selected_range: Range, source: str) -> str: + self.visitor = self.visitor_type(selected_range=selected_range) selected_code = self.selected_code(selected_range=selected_range, source=source) self.visit_visitor(selected_code=selected_code) transformer = self.create_transformer(selected_range=selected_range) return self.get_transformed_code(transformer=transformer, source=source) - @staticmethod - def selected_code(selected_range: Range, source: str) -> str: # noqa: WPS602 + @classmethod + def selected_code(cls, selected_range: Range, source: str) -> str: # noqa: WPS602 start: Position = selected_range.start end: Position = selected_range.end end_char = end.character @@ -37,8 +40,8 @@ def create_transformer(self, selected_range: Range) -> RefactoringTransformer: raise NotImplementedError() def visit_visitor(self, selected_code: str) -> None: - range_tree = cst.parse_module(selected_code) - range_tree.visit(visitor=self.visitor) + range_tree = cst.MetadataWrapper(cst.parse_module(selected_code)) + range_tree.visit(visitor=self.visitor) # type: ignore def get_transformed_code(self, transformer: RefactoringTransformer, source: str) -> str: source_tree = cst.MetadataWrapper(cst.parse_module(source=source)) diff --git a/src/refacto/refactorings/visitor.py b/src/refacto/refactorings/visitor.py index 13391c1..f6edb7f 100644 --- a/src/refacto/refactorings/visitor.py +++ b/src/refacto/refactorings/visitor.py @@ -1,5 +1,7 @@ import libcst +from pygls.lsp.types.basic_structures import Range class RefactoringVisitor(libcst.CSTVisitor): - """Root Visitor for Refacto.""" + def __init__(self, selected_range: Range) -> None: + self.selected_range = selected_range diff --git a/tests/unit_tests/test_get_node_from_range.py b/tests/unit_tests/test_get_node_from_range.py new file mode 100644 index 0000000..d548f03 --- /dev/null +++ b/tests/unit_tests/test_get_node_from_range.py @@ -0,0 +1,60 @@ +import libcst as cst +from libcst.metadata.parent_node_provider import ParentNodeProvider +from libcst.metadata.position_provider import PositionProvider +from libcst.metadata.scope_provider import Scope +from libcst.metadata.scope_provider import ScopeProvider +from pygls.lsp.types.basic_structures import Position +from pygls.lsp.types.basic_structures import Range + + +class FindNode(cst.CSTVisitor): + METADATA_DEPENDENCIES = ( + PositionProvider, + ScopeProvider, + ParentNodeProvider, + ) + + def __init__(self, selected_range: Range) -> None: + self.selected_range = selected_range + self.node: cst.CSTNode | None = None + self.parent: cst.CSTNode | None = None + self.scope: Scope | None = None + + super().__init__() + + def visit_Name(self, node: cst.Name) -> bool: + if self._is_same_starting_position(node=node): + self._set_things(node=node) + return False + return True + + def _set_things(self, node: cst.CSTNode) -> None: + self.node = node + self.scope = self.get_metadata(ScopeProvider, node) + try: + self.parent = self.get_metadata(ParentNodeProvider, node) + except KeyError: + # No parent + return + + def _is_same_starting_position(self, node: cst.CSTNode) -> bool: + libcst_range = self.get_metadata(PositionProvider, node) + return all( + [ + self.selected_range.start.line == libcst_range.start.line - 1, + self.selected_range.start.character == libcst_range.start.column, + ], + ) + + +def test_stuff(): + selected_range = Range( + start=Position(line=0, character=0), + end=Position(line=0, character=1), + ) + with open("tests/test_cases/inline_variable/simplest/before.py", "r") as src: + code = src.read() + module = cst.MetadataWrapper(cst.parse_module(source=code)) + visitor = FindNode(selected_range=selected_range) + module.visit(visitor=visitor) + assert visitor.node is not None