diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3006c2bfb..c7a63ab18d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,13 +3,13 @@ repos: rev: 23.3.0 hooks: - id: black - exclude: ^tests/codegen/snapshots/python/ + exclude: ^tests/\w+/snapshots/ - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.272 hooks: - id: ruff - exclude: ^tests/codegen/snapshots/python/ + exclude: ^tests/\w+/snapshots/ - repo: https://github.com/patrick91/pre-commit-alex rev: aa5da9e54b92ab7284feddeaf52edf14b1690de3 diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..fe575df8d1 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,13 @@ +Release type: minor + +This release introduces a new command called `upgrade`, this command can be used +to run codemods on your codebase to upgrade to the latest version of Strawberry. + +At the moment we only support upgrading unions to use the new syntax with +annotated, but in future we plan to add more commands to help with upgrading. + +Here's how you can use the command to upgrade your codebase: + +```shell +strawberry upgrade annotated-union . +``` diff --git a/docs/README.md b/docs/README.md index 5ae16f3f23..a42315b9ca 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,6 +10,7 @@ - [Subscriptions](./general/subscriptions.md) - [Why](./general/why.md) - [Breaking changes](./breaking-changes.md) +- [Breaking changes](./general/upgrades.md) - [FAQ](./faq.md) ## Types diff --git a/docs/general/upgrades.md b/docs/general/upgrades.md new file mode 100644 index 0000000000..2b925f3280 --- /dev/null +++ b/docs/general/upgrades.md @@ -0,0 +1,24 @@ +--- +title: Upgrading Strawberry +--- + +# Upgrading Strawberry + + + +We try to keep Strawberry as backwards compatible as possible, but sometimes we +need to make updates to the public API. While we try to deprecate APIs before +removing them, we also want to make it as easy as possible to upgrade to the +latest version of Strawberry. + +For this reason we provide a CLI command that can automatically upgrade your +codebase to use the updated APIs. + +At the moment we only support updating unions to use the new syntax with +annotated, but in future we plan to add more commands to help with upgrading. + +Here's how you can use the command to upgrade your codebase: + +```shell +strawberry upgrade annotated-union . +``` diff --git a/pyproject.toml b/pyproject.toml index 8ba334b048..5105b5ef11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,7 @@ strawberry = "strawberry.cli:run" line-length = 88 extend-exclude = ''' tests/codegen/snapshots/ +tests/cli/snapshots/ ''' [tool.pytest.ini_options] @@ -338,6 +339,7 @@ src = ["strawberry", "tests"] "tests/federation/printer/*" = ["E501"] "tests/test_printer/test_basic.py" = ["E501"] "tests/pyright/test_federation.py" = ["E501"] +"tests/codemods/*" = ["E501"] "tests/test_printer/test_schema_directives.py" = ["E501"] "tests/*" = ["RSE102", "SLF001", "TCH001", "TCH002", "TCH003", "ANN001", "ANN201", "PLW0603", "PLC1901", "S603", "S607", "B018"] "strawberry/extensions/tracing/__init__.py" = ["TCH004"] diff --git a/strawberry/cli/__init__.py b/strawberry/cli/__init__.py index e31401b122..9a2fe791a6 100644 --- a/strawberry/cli/__init__.py +++ b/strawberry/cli/__init__.py @@ -1,7 +1,7 @@ -from .commands.codegen import codegen # noqa -from .commands.export_schema import export_schema # noqa -from .commands.server import server # noqa - +from .commands.codegen import codegen as codegen # noqa +from .commands.export_schema import export_schema as export_schema # noqa +from .commands.server import server as server # noqa +from .commands.upgrade import upgrade as upgrade # noqa from .app import app diff --git a/strawberry/cli/commands/upgrade/__init__.py b/strawberry/cli/commands/upgrade/__init__.py new file mode 100644 index 0000000000..57e1f556dd --- /dev/null +++ b/strawberry/cli/commands/upgrade/__init__.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import glob +import pathlib # noqa: TCH003 +import sys +from typing import List + +import rich +import typer +from libcst.codemod import CodemodContext + +from strawberry.cli.app import app +from strawberry.codemods.annotated_unions import ConvertUnionToAnnotatedUnion + +from ._run_codemod import run_codemod + +codemods = { + "annotated-union": ConvertUnionToAnnotatedUnion, +} + + +# TODO: add support for running all of them +@app.command(help="Upgrades a Strawberry project to the latest version") +def upgrade( + codemod: str = typer.Argument( + ..., + autocompletion=lambda: list(codemods.keys()), + help="Name of the upgrade to run", + ), + paths: List[pathlib.Path] = typer.Argument(file_okay=True, dir_okay=True), + python_target: str = typer.Option( + ".".join(str(x) for x in sys.version_info[:2]), + "--python-target", + help="Python version to target", + ), + use_typing_extensions: bool = typer.Option( + False, + "--use-typing-extensions", + help="Use typing_extensions instead of typing for newer features", + ), +) -> None: + if codemod not in codemods: + rich.print(f'[red]Upgrade named "{codemod}" does not exist') + + raise typer.Exit(2) + + python_target_version = tuple(int(x) for x in python_target.split(".")) + + transformer = ConvertUnionToAnnotatedUnion( + CodemodContext(), + use_pipe_syntax=python_target_version >= (3, 10), + use_typing_extensions=use_typing_extensions, + ) + + files: list[str] = [] + + for path in paths: + if path.is_dir(): + glob_path = str(path / "**/*.py") + files.extend(glob.glob(glob_path, recursive=True)) + else: + files.append(str(path)) + + files = list(set(files)) + + results = list(run_codemod(transformer, files)) + changed = [result for result in results if result.changed] + + rich.print() + rich.print("[green]Upgrade completed successfully, here's a summary:") + rich.print(f" - {len(changed)} files changed") + rich.print(f" - {len(results) - len(changed)} files skipped") + + if changed: + raise typer.Exit(1) diff --git a/strawberry/cli/commands/upgrade/_fake_progress.py b/strawberry/cli/commands/upgrade/_fake_progress.py new file mode 100644 index 0000000000..40d92da5a0 --- /dev/null +++ b/strawberry/cli/commands/upgrade/_fake_progress.py @@ -0,0 +1,21 @@ +from typing import Any + +from rich.progress import TaskID + + +class FakeProgress: + """A fake progress bar that does nothing. + + This is used when the user has only one file to process.""" + + def advance(self, task_id: TaskID) -> None: + pass + + def add_task(self, *args: Any, **kwargs: Any) -> TaskID: + return TaskID(0) + + def __enter__(self) -> "FakeProgress": + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + pass diff --git a/strawberry/cli/commands/upgrade/_run_codemod.py b/strawberry/cli/commands/upgrade/_run_codemod.py new file mode 100644 index 0000000000..d08cbc05e9 --- /dev/null +++ b/strawberry/cli/commands/upgrade/_run_codemod.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import contextlib +import os +from multiprocessing import Pool, cpu_count +from typing import TYPE_CHECKING, Any, Dict, Generator, Sequence, Type, Union + +from libcst.codemod._cli import ExecutionConfig, ExecutionResult, _execute_transform +from libcst.codemod._dummy_pool import DummyPool +from rich.progress import Progress + +from ._fake_progress import FakeProgress + +if TYPE_CHECKING: + from libcst.codemod import Codemod + +ProgressType = Union[Type[Progress], Type[FakeProgress]] +PoolType = Union[Type[Pool], Type[DummyPool]] # type: ignore + + +def _execute_transform_wrap( + job: Dict[str, Any], +) -> ExecutionResult: + # TODO: maybe capture warnings? + with open(os.devnull, "w") as null: # noqa: PTH123 + with contextlib.redirect_stderr(null): + return _execute_transform(**job) + + +def _get_progress_and_pool( + total_files: int, jobs: int +) -> tuple[PoolType, ProgressType]: + poll_impl: PoolType = Pool # type: ignore + progress_impl: ProgressType = Progress + + if total_files == 1 or jobs == 1: + poll_impl = DummyPool + + if total_files == 1: + progress_impl = FakeProgress + + return poll_impl, progress_impl + + +def run_codemod( + codemod: Codemod, + files: Sequence[str], +) -> Generator[ExecutionResult, None, None]: + chunk_size = 4 + total = len(files) + jobs = min(cpu_count(), (total + chunk_size - 1) // chunk_size) + + config = ExecutionConfig() + + pool_impl, progress_impl = _get_progress_and_pool(total, jobs) + + tasks = [ + { + "transformer": codemod, + "filename": filename, + "config": config, + } + for filename in files + ] + + with pool_impl(processes=jobs) as p, progress_impl() as progress: # type: ignore + task_id = progress.add_task("[cyan]Updating...", total=len(tasks)) + + for result in p.imap_unordered( + _execute_transform_wrap, tasks, chunksize=chunk_size + ): + progress.advance(task_id) + + yield result diff --git a/strawberry/codemods/__init__.py b/strawberry/codemods/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/strawberry/codemods/annotated_unions.py b/strawberry/codemods/annotated_unions.py new file mode 100644 index 0000000000..14ab1efdd7 --- /dev/null +++ b/strawberry/codemods/annotated_unions.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import Optional, Sequence + +import libcst as cst +import libcst.matchers as m +from libcst._nodes.expression import BaseExpression, Call # noqa: TCH002 +from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand +from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor + + +def _find_named_argument(args: Sequence[cst.Arg], name: str) -> cst.Arg | None: + return next( + (arg for arg in args if arg.keyword and arg.keyword.value == name), + None, + ) + + +def _find_positional_argument( + args: Sequence[cst.Arg], search_index: int +) -> cst.Arg | None: + for index, arg in enumerate(args): + if index > search_index: + return None + + if index == search_index and arg.keyword is None: + return arg + + return None + + +class ConvertUnionToAnnotatedUnion(VisitorBasedCodemodCommand): + DESCRIPTION: str = ( + "Converts strawberry.union(..., types=(...)) to " + "Annotated[Union[...], strawberry.union(...)]" + ) + + def __init__( + self, + context: CodemodContext, + use_pipe_syntax: bool = True, + use_typing_extensions: bool = False, + ) -> None: + self._is_using_named_import = False + self.use_pipe_syntax = use_pipe_syntax + self.use_typing_extensions = use_typing_extensions + + super().__init__(context) + + def visit_Module(self, node: cst.Module) -> Optional[bool]: + self._is_using_named_import = False + + return super().visit_Module(node) + + @m.visit( + m.ImportFrom( + m.Name("strawberry"), + [ + m.ZeroOrMore(), + m.ImportAlias(m.Name("union")), + m.ZeroOrMore(), + ], + ) + ) + def visit_import_from(self, original_node: cst.ImportFrom) -> None: + self._is_using_named_import = True + + @m.leave( + m.Call( + func=m.Attribute(value=m.Name("strawberry"), attr=m.Name("union")) + | m.Name("union") + ) + ) + def leave_union_call( + self, original_node: Call, updated_node: Call + ) -> BaseExpression: + if not self._is_using_named_import and isinstance(original_node.func, cst.Name): + return original_node + + types = _find_named_argument(original_node.args, "types") + union_name = _find_named_argument(original_node.args, "name") + + if types is None: + types = _find_positional_argument(original_node.args, 1) + + # this is probably a strawberry.union(name="...") so we skip the conversion + # as it is going to be used in the new way already 😊 + + if types is None: + return original_node + + AddImportsVisitor.add_needed_import( + self.context, + "typing_extensions" if self.use_typing_extensions else "typing", + "Annotated", + ) + + RemoveImportsVisitor.remove_unused_import(self.context, "strawberry", "union") + + if union_name is None: + union_name = _find_positional_argument(original_node.args, 0) + + assert union_name + assert isinstance(types.value, (cst.Tuple, cst.List)) + + types = types.value.elements # type: ignore + union_name = union_name.value # type: ignore + + description = _find_named_argument(original_node.args, "description") + directives = _find_named_argument(original_node.args, "directives") + + if self.use_pipe_syntax: + union_node = self._create_union_node_with_pipe_syntax(types) # type: ignore + else: + AddImportsVisitor.add_needed_import(self.context, "typing", "Union") + + union_node = cst.Subscript( + value=cst.Name(value="Union"), + slice=[ + cst.SubscriptElement(slice=cst.Index(value=t.value)) for t in types # type: ignore # noqa: E501 + ], + ) + + union_call_args = [ + cst.Arg( + value=union_name, # type: ignore + keyword=cst.Name(value="name"), + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + ) + ] + + additional_args = {"description": description, "directives": directives} + + union_call_args.extend( + cst.Arg( + value=arg.value, + keyword=cst.Name(name), + equal=cst.AssignEqual( + whitespace_before=cst.SimpleWhitespace(""), + whitespace_after=cst.SimpleWhitespace(""), + ), + ) + for name, arg in additional_args.items() + if arg is not None + ) + + union_call_node = cst.Call( + func=cst.Attribute( + value=cst.Name(value="strawberry"), + attr=cst.Name(value="union"), + ), + args=union_call_args, + ) + + return cst.Subscript( + value=cst.Name(value="Annotated"), + slice=[ + cst.SubscriptElement( + slice=cst.Index( + value=union_node, + ), + ), + cst.SubscriptElement( + slice=cst.Index( + value=union_call_node, + ), + ), + ], + ) + + @classmethod + def _create_union_node_with_pipe_syntax( + cls, types: Sequence[cst.BaseElement] + ) -> cst.BaseExpression: + type_names = [t.value for t in types] + + if not all(isinstance(t, cst.Name) for t in type_names): + raise ValueError("Only names are supported for now") + + expression = " | ".join(name.value for name in type_names) # type: ignore + + return cst.parse_expression(expression) diff --git a/tests/cli/fixtures/unions.py b/tests/cli/fixtures/unions.py new file mode 100644 index 0000000000..b3559873db --- /dev/null +++ b/tests/cli/fixtures/unions.py @@ -0,0 +1,29 @@ +import strawberry + +# create a few types and then a union type + + +@strawberry.type +class Foo: + a: str + + +@strawberry.type +class Bar: + b: str + + +@strawberry.type +class Baz: + c: str + + +@strawberry.type +class Qux: + d: str + + +# this is the union type + +Union1 = strawberry.union(name="Union1", types=(Foo, Bar, Baz, Qux)) +Union2 = strawberry.union(name="Union2", types=(Baz, Qux)) diff --git a/tests/cli/snapshots/unions.py b/tests/cli/snapshots/unions.py new file mode 100644 index 0000000000..2f03e0bb4b --- /dev/null +++ b/tests/cli/snapshots/unions.py @@ -0,0 +1,30 @@ +import strawberry +from typing import Annotated + +# create a few types and then a union type + + +@strawberry.type +class Foo: + a: str + + +@strawberry.type +class Bar: + b: str + + +@strawberry.type +class Baz: + c: str + + +@strawberry.type +class Qux: + d: str + + +# this is the union type + +Union1 = Annotated[Foo | Bar | Baz | Qux, strawberry.union(name="Union1")] +Union2 = Annotated[Baz | Qux, strawberry.union(name="Union2")] diff --git a/tests/cli/snapshots/unions_py38.py b/tests/cli/snapshots/unions_py38.py new file mode 100644 index 0000000000..4fd0143c69 --- /dev/null +++ b/tests/cli/snapshots/unions_py38.py @@ -0,0 +1,30 @@ +import strawberry +from typing import Annotated, Union + +# create a few types and then a union type + + +@strawberry.type +class Foo: + a: str + + +@strawberry.type +class Bar: + b: str + + +@strawberry.type +class Baz: + c: str + + +@strawberry.type +class Qux: + d: str + + +# this is the union type + +Union1 = Annotated[Union[Foo, Bar, Baz, Qux], strawberry.union(name="Union1")] +Union2 = Annotated[Union[Baz, Qux], strawberry.union(name="Union2")] diff --git a/tests/cli/snapshots/unions_typing_extension.py b/tests/cli/snapshots/unions_typing_extension.py new file mode 100644 index 0000000000..11c5f01cfb --- /dev/null +++ b/tests/cli/snapshots/unions_typing_extension.py @@ -0,0 +1,30 @@ +import strawberry +from typing_extensions import Annotated + +# create a few types and then a union type + + +@strawberry.type +class Foo: + a: str + + +@strawberry.type +class Bar: + b: str + + +@strawberry.type +class Baz: + c: str + + +@strawberry.type +class Qux: + d: str + + +# this is the union type + +Union1 = Annotated[Foo | Bar | Baz | Qux, strawberry.union(name="Union1")] +Union2 = Annotated[Baz | Qux, strawberry.union(name="Union2")] diff --git a/tests/cli/test_upgrade.py b/tests/cli/test_upgrade.py new file mode 100644 index 0000000000..b449953ca6 --- /dev/null +++ b/tests/cli/test_upgrade.py @@ -0,0 +1,85 @@ +from pathlib import Path + +from pytest_snapshot.plugin import Snapshot +from typer.testing import CliRunner + +from strawberry.cli.app import app + +HERE = Path(__file__).parent + + +def test_upgrade_returns_error_code_if_codemod_does_not_exist(cli_runner: CliRunner): + result = cli_runner.invoke( + app, + ["upgrade", "a_random_codemod", "."], + ) + + assert result.exit_code == 2 + assert 'Upgrade named "a_random_codemod" does not exist' in result.stdout + + +def test_upgrade_works_annotated_unions( + cli_runner: CliRunner, tmp_path: Path, snapshot: Snapshot +): + source = HERE / "fixtures/unions.py" + + target = tmp_path / "unions.py" + target.write_text(source.read_text()) + + result = cli_runner.invoke( + app, + ["upgrade", "--python-target", "3.11", "annotated-union", str(target)], + ) + + assert result.exit_code == 1 + assert "1 files changed\n - 0 files skipped" in result.stdout + + snapshot.snapshot_dir = HERE / "snapshots" + snapshot.assert_match(target.read_text(), "unions.py") + + +def test_upgrade_works_annotated_unions_target_python( + cli_runner: CliRunner, tmp_path: Path, snapshot: Snapshot +): + source = HERE / "fixtures/unions.py" + + target = tmp_path / "unions.py" + target.write_text(source.read_text()) + + result = cli_runner.invoke( + app, + ["upgrade", "--python-target", "3.8", "annotated-union", str(target)], + ) + + assert result.exit_code == 1 + assert "1 files changed\n - 0 files skipped" in result.stdout + + snapshot.snapshot_dir = HERE / "snapshots" + snapshot.assert_match(target.read_text(), "unions_py38.py") + + +def test_upgrade_works_annotated_unions_typing_extensions( + cli_runner: CliRunner, tmp_path: Path, snapshot: Snapshot +): + source = HERE / "fixtures/unions.py" + + target = tmp_path / "unions.py" + target.write_text(source.read_text()) + + result = cli_runner.invoke( + app, + [ + "upgrade", + "--use-typing-extensions", + "--python-target", + "3.11", + "annotated-union", + str(target), + ], + ) + + assert result.exit_code == 1 + assert "1 files changed\n - 0 files skipped" in result.stdout + + snapshot.snapshot_dir = HERE / "snapshots" + snapshot.assert_match(target.read_text(), "unions_typing_extension.py") diff --git a/tests/codemods/__init__.py b/tests/codemods/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/codemods/test_annotated_unions.py b/tests/codemods/test_annotated_unions.py new file mode 100644 index 0000000000..af86675219 --- /dev/null +++ b/tests/codemods/test_annotated_unions.py @@ -0,0 +1,146 @@ +from libcst.codemod import CodemodTest + +from strawberry.codemods.annotated_unions import ConvertUnionToAnnotatedUnion + + +class TestConvertConstantCommand(CodemodTest): + TRANSFORM = ConvertUnionToAnnotatedUnion + + def test_update_union(self) -> None: + before = """ + AUnion = strawberry.union(name="ABC", types=(Foo, Bar)) + """ + + after = """ + from typing import Annotated, Union + + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + self.assertCodemod( + before, after, use_pipe_syntax=False, use_typing_extensions=False + ) + + def test_update_union_typing_extensions(self) -> None: + before = """ + AUnion = strawberry.union(name="ABC", types=(Foo, Bar)) + """ + + after = """ + from typing import Annotated, Union + + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) + + def test_update_union_using_import(self) -> None: + before = """ + from strawberry import union + + AUnion = union(name="ABC", types=(Foo, Bar)) + """ + + after = """ + from typing import Annotated, Union + + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) + + def test_noop_other_union(self) -> None: + before = """ + from potato import union + + union("A", "B") + """ + + after = """ + from potato import union + + union("A", "B") + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) + + def test_update_union_positional_name(self) -> None: + before = """ + AUnion = strawberry.union("ABC", types=(Foo, Bar)) + """ + + after = """ + from typing import Annotated, Union + + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) + + def test_update_swapped_kwargs(self) -> None: + before = """ + AUnion = strawberry.union(types=(Foo, Bar), name="ABC") + """ + + after = """ + from typing import Annotated, Union + + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) + + def test_update_union_list(self) -> None: + before = """ + AUnion = strawberry.union(name="ABC", types=[Foo, Bar]) + """ + + after = """ + from typing import Annotated, Union + + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) + + def test_update_positional_arguments(self) -> None: + before = """ + AUnion = strawberry.union("ABC", (Foo, Bar)) + """ + + after = """ + from typing import Annotated, Union + + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) + + def test_supports_directives_and_description(self) -> None: + before = """ + AUnion = strawberry.union( + "ABC", + (Foo, Bar), + description="cool union", + directives=[object()], + ) + """ + + after = """ + from typing import Annotated, Union + + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC", description="cool union", directives=[object()])] + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) + + def test_noop_with_annotated_unions(self) -> None: + before = """ + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + after = """ + AUnion = Annotated[Union[Foo, Bar], strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=False) diff --git a/tests/codemods/test_annotated_unions_pipe.py b/tests/codemods/test_annotated_unions_pipe.py new file mode 100644 index 0000000000..c5705720c2 --- /dev/null +++ b/tests/codemods/test_annotated_unions_pipe.py @@ -0,0 +1,131 @@ +from libcst.codemod import CodemodTest + +from strawberry.codemods.annotated_unions import ConvertUnionToAnnotatedUnion + + +class TestConvertConstantCommand(CodemodTest): + TRANSFORM = ConvertUnionToAnnotatedUnion + + def test_update_union(self) -> None: + before = """ + AUnion = strawberry.union(name="ABC", types=(Foo, Bar)) + """ + + after = """ + from typing import Annotated + + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=True) + + def test_update_union_using_import(self) -> None: + before = """ + from strawberry import union + + AUnion = union(name="ABC", types=(Foo, Bar)) + """ + + after = """ + from typing import Annotated + + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=True) + + def test_noop_other_union(self) -> None: + before = """ + from potato import union + + union("A", "B") + """ + + after = """ + from potato import union + + union("A", "B") + """ + + self.assertCodemod(before, after, use_pipe_syntax=True) + + def test_update_union_positional_name(self) -> None: + before = """ + AUnion = strawberry.union("ABC", types=(Foo, Bar)) + """ + + after = """ + from typing import Annotated + + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=True) + + def test_update_swapped_kwargs(self) -> None: + before = """ + AUnion = strawberry.union(types=(Foo, Bar), name="ABC") + """ + + after = """ + from typing import Annotated + + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=True) + + def test_update_union_list(self) -> None: + before = """ + AUnion = strawberry.union(name="ABC", types=[Foo, Bar]) + """ + + after = """ + from typing import Annotated + + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=True) + + def test_update_positional_arguments(self) -> None: + before = """ + AUnion = strawberry.union("ABC", (Foo, Bar)) + """ + + after = """ + from typing import Annotated + + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=True) + + def test_supports_directives_and_description(self) -> None: + before = """ + AUnion = strawberry.union( + "ABC", + (Foo, Bar), + description="cool union", + directives=[object()], + ) + """ + + after = """ + from typing import Annotated + + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC", description="cool union", directives=[object()])] + """ + + self.assertCodemod(before, after, use_pipe_syntax=True) + + def test_noop_with_annotated_unions(self) -> None: + before = """ + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC")] + """ + + after = """ + AUnion = Annotated[Foo | Bar, strawberry.union(name="ABC")] + """ + + self.assertCodemod(before, after, use_pipe_syntax=True)