diff --git a/approvaltests/inline/split_code.py b/approvaltests/inline/split_code.py index ae6add96..c93d45b9 100644 --- a/approvaltests/inline/split_code.py +++ b/approvaltests/inline/split_code.py @@ -1,3 +1,7 @@ +import re +from enum import Enum + + class SplitCode: def __init__(self, before_method, after_method, tab): self.before_method = before_method @@ -12,34 +16,33 @@ def on_method(code, method_name) -> "SplitCode": lines = code.split("\n") before = [] after = [] - inside_method = False - inside_doc_string = False tab = " " - after_method = False - state = 0 + class State(Enum): + BEFORE = 0 + FIRST_LINE_OF_METHOD_BODY = 1 + IN_DOCSTRING = 2 + AFTER_DOCTSTRING = 3 + state = State.BEFORE for line in lines: stripped_line = line.strip() - - if state == 0: - before.append(line) if stripped_line.startswith(f"def {method_name}("): - state = 1 - continue - if state == 1: - tab = line[: line.find(stripped_line)] + state = State.FIRST_LINE_OF_METHOD_BODY + before.append(line) + elif state == State.BEFORE: + before.append(line) + elif state == State.FIRST_LINE_OF_METHOD_BODY: + tab = re.compile(r'^\s*').match(line).group() if stripped_line.startswith('"""'): - state = 2 - continue + state = State.IN_DOCSTRING else: - state = 3 - if state == 2: + state = State.AFTER_DOCTSTRING + after.append(line) + elif state == State.IN_DOCSTRING: if stripped_line.startswith('"""'): - state = 3 - continue - if state == 3: + state = State.AFTER_DOCTSTRING + elif state == State.AFTER_DOCTSTRING: after.append(line) - return SplitCode("\n".join(before), "\n".join(after), tab) def indent(self, received_text): diff --git a/tests/test_inline_approvals.py b/tests/test_inline_approvals.py index 4655f7cf..cb022f12 100644 --- a/tests/test_inline_approvals.py +++ b/tests/test_inline_approvals.py @@ -47,7 +47,7 @@ def fizz_buzz(param): return return_string -def test_fizz_buzz_to_15(): +def test_fizz_buzz(): """ 1 2 @@ -64,7 +64,7 @@ def test_fizz_buzz_to_15(): def test_docstrings(): """ hello - world + world """ # verify_inline(greetting()) # verify(greetting(), options=Options().inline(show_code= False)) @@ -72,7 +72,7 @@ def test_docstrings(): def greeting(): - return "hello \n world" + return "hello\nworld" class InlineReporter(Reporter):