Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/refacto/refactorings/extract_variable/refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

class RefactorExtractVariable(Refactor):
def __init__(self) -> None:
self.visitor: ExpressionFinder = ExpressionFinder()
super().__init__(visitor=ExpressionFinder)

def create_transformer(self, selected_range: Range) -> RefactoringTransformer:
if self.visitor.expr is None:
if self.visitor.expr is None: # type:ignore
raise RuntimeError("Node was unexpectically None!")
return ExtractVariableTransformer(
expr=self.visitor.expr,
expr=self.visitor.expr, # type: ignore
variable_name="rename_me",
selected_range=selected_range,
)
5 changes: 3 additions & 2 deletions src/refacto/refactorings/extract_variable/visitor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import libcst as cst
from pygls.lsp.types.basic_structures import Range

from refacto.refactorings.visitor import RefactoringVisitor


class ExpressionFinder(RefactoringVisitor):
def __init__(self) -> None:
super().__init__()
def __init__(self, selected_range: Range) -> None:
self.expr: cst.Expr | None = None
super().__init__(selected_range=selected_range)

def visit_Expr(self, node: cst.Expr) -> bool:
self.expr = node
Expand Down
6 changes: 3 additions & 3 deletions src/refacto/refactorings/inline_variable/refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

class RefactorInlineVariable(Refactor):
def __init__(self) -> None:
self.visitor: NameFinder = NameFinder()
super().__init__(visitor=NameFinder)

def create_transformer(self, selected_range: Range) -> RefactoringTransformer:
if self.visitor.name is None:
if self.visitor.name is None: # type: ignore
raise RuntimeError("Couldn't find variable to inline :-(")
return InlineVariableTransformer(
name=self.visitor.name,
name=self.visitor.name, # type: ignore
selected_range=selected_range,
)
40 changes: 38 additions & 2 deletions src/refacto/refactorings/inline_variable/visitor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,48 @@
from devtools import debug
import libcst as cst
from libcst.metadata import ParentNodeProvider
from libcst.metadata import WhitespaceInclusivePositionProvider
from libcst.metadata.scope_provider import Scope
from libcst.metadata.scope_provider import ScopeProvider
from pygls.lsp.types.basic_structures import Range

from refacto.refactorings.visitor import RefactoringVisitor


class NameFinder(RefactoringVisitor):
def __init__(self) -> None:
METADATA_DEPENDENCIES = (
WhitespaceInclusivePositionProvider,
ScopeProvider,
ParentNodeProvider,
)

def __init__(self, selected_range: Range) -> None:
self.name: cst.Name | None = None
self.parent: cst.CSTNode | None = None
self.scope: Scope | None = None
super().__init__(selected_range=selected_range)

def visit_Name(self, node: cst.Name) -> bool:
if self._is_same_starting_position(node=node):
self._set_things(node=node)
return False
return True

def _set_things(self, node: cst.Name) -> None:
self.name = node
return False
self.scope = self.get_metadata(ScopeProvider, node)
try:
self.parent = self.get_metadata(ParentNodeProvider, node)
except KeyError:
return

def _is_same_starting_position(self, node: cst.Name) -> bool:
libcst_range = self.get_metadata(WhitespaceInclusivePositionProvider, node)
if node.value == "shipping_price":
debug(str(libcst_range), node)
return all(
[
self.selected_range.start.line == libcst_range.start.line - 1,
self.selected_range.start.character == libcst_range.start.column,
],
)
15 changes: 9 additions & 6 deletions src/refacto/refactorings/refactor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC
from abc import abstractmethod
from typing import Type

import libcst as cst
from pygls.lsp.types.basic_structures import Position
Expand All @@ -11,17 +12,19 @@


class Refactor(ABC):
def __init__(self, visitor: RefactoringVisitor) -> None:
self.visitor = visitor
def __init__(self, visitor: Type[RefactoringVisitor]) -> None:
self.visitor_type = visitor
self.visitor: RefactoringVisitor | None = None

def refactor(self, selected_range: Range, source: str) -> str:
self.visitor = self.visitor_type(selected_range=selected_range)
selected_code = self.selected_code(selected_range=selected_range, source=source)
self.visit_visitor(selected_code=selected_code)
transformer = self.create_transformer(selected_range=selected_range)
return self.get_transformed_code(transformer=transformer, source=source)

@staticmethod
def selected_code(selected_range: Range, source: str) -> str: # noqa: WPS602
@classmethod
def selected_code(cls, selected_range: Range, source: str) -> str: # noqa: WPS602
start: Position = selected_range.start
end: Position = selected_range.end
end_char = end.character
Expand All @@ -37,8 +40,8 @@ def create_transformer(self, selected_range: Range) -> RefactoringTransformer:
raise NotImplementedError()

def visit_visitor(self, selected_code: str) -> None:
range_tree = cst.parse_module(selected_code)
range_tree.visit(visitor=self.visitor)
range_tree = cst.MetadataWrapper(cst.parse_module(selected_code))
range_tree.visit(visitor=self.visitor) # type: ignore

def get_transformed_code(self, transformer: RefactoringTransformer, source: str) -> str:
source_tree = cst.MetadataWrapper(cst.parse_module(source=source))
Expand Down
4 changes: 3 additions & 1 deletion src/refacto/refactorings/visitor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import libcst
from pygls.lsp.types.basic_structures import Range


class RefactoringVisitor(libcst.CSTVisitor):
"""Root Visitor for Refacto."""
def __init__(self, selected_range: Range) -> None:
self.selected_range = selected_range
60 changes: 60 additions & 0 deletions tests/unit_tests/test_get_node_from_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import libcst as cst
from libcst.metadata.parent_node_provider import ParentNodeProvider
from libcst.metadata.position_provider import PositionProvider
from libcst.metadata.scope_provider import Scope
from libcst.metadata.scope_provider import ScopeProvider
from pygls.lsp.types.basic_structures import Position
from pygls.lsp.types.basic_structures import Range


class FindNode(cst.CSTVisitor):
METADATA_DEPENDENCIES = (
PositionProvider,
ScopeProvider,
ParentNodeProvider,
)

def __init__(self, selected_range: Range) -> None:
self.selected_range = selected_range
self.node: cst.CSTNode | None = None
self.parent: cst.CSTNode | None = None
self.scope: Scope | None = None

super().__init__()

def visit_Name(self, node: cst.Name) -> bool:
if self._is_same_starting_position(node=node):
self._set_things(node=node)
return False
return True

def _set_things(self, node: cst.CSTNode) -> None:
self.node = node
self.scope = self.get_metadata(ScopeProvider, node)
try:
self.parent = self.get_metadata(ParentNodeProvider, node)
except KeyError:
# No parent
return

def _is_same_starting_position(self, node: cst.CSTNode) -> bool:
libcst_range = self.get_metadata(PositionProvider, node)
return all(
[
self.selected_range.start.line == libcst_range.start.line - 1,
self.selected_range.start.character == libcst_range.start.column,
],
)


def test_stuff():
selected_range = Range(
start=Position(line=0, character=0),
end=Position(line=0, character=1),
)
with open("tests/test_cases/inline_variable/simplest/before.py", "r") as src:
code = src.read()
module = cst.MetadataWrapper(cst.parse_module(source=code))
visitor = FindNode(selected_range=selected_range)
module.visit(visitor=visitor)
assert visitor.node is not None