diff --git a/strawberry/cli/commands/codegen.py b/strawberry/cli/commands/codegen.py index 5e434c0071..533a649b4c 100644 --- a/strawberry/cli/commands/codegen.py +++ b/strawberry/cli/commands/codegen.py @@ -1,8 +1,11 @@ from __future__ import annotations +import contextlib +import functools import importlib import inspect -from pathlib import Path # noqa: TCH003 +import warnings +from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Type import rich @@ -11,6 +14,8 @@ from strawberry.cli.app import app from strawberry.cli.utils import load_schema from strawberry.codegen import QueryCodegen, QueryCodegenPlugin +from strawberry.codegen.plugins.python import PythonPlugin +from strawberry.codegen.plugins.typescript import TypeScriptPlugin if TYPE_CHECKING: from strawberry.codegen import CodegenResult @@ -62,6 +67,7 @@ def _import_plugin(plugin: str) -> Optional[Type[QueryCodegenPlugin]]: return None +@functools.lru_cache() def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]: # try to import plugin_name from current folder # then try to import from strawberry.codegen.plugins @@ -78,8 +84,26 @@ def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]: return plugin -def _load_plugins(plugins: List[str]) -> List[QueryCodegenPlugin]: - return [_load_plugin(plugin)() for plugin in plugins] +def _load_plugins(plugin_ids: List[str], query: Path) -> List[QueryCodegenPlugin]: + plugins = [] + for ptype_id in plugin_ids: + ptype = _load_plugin(ptype_id) + # inject arguments into the plugin's `__init__` based on + # the names in the signature (e.g., similar to flake8 plugins). + sig = inspect.signature(ptype) + kwargs = {} + if "query" in sig.parameters: + kwargs["query"] = query + if "query_file" in sig.parameters: + kwargs["query_file"] = query + + plugin = ptype(**kwargs) + if not hasattr(plugin, "query"): + plugin.query = query + + plugins.append(plugin) + + return plugins class ConsolePlugin(QueryCodegenPlugin): @@ -89,13 +113,20 @@ def __init__( self.query = query self.output_dir = output_dir self.plugins = plugins + self.files_generated: List[Path] = [] - def on_start(self) -> None: + def before_any_start(self) -> None: rich.print( "[bold yellow]The codegen is experimental. Please submit any bug at " "https://github.com/strawberry-graphql/strawberry\n", ) + def after_all_finished(self) -> None: + rich.print("[green]Generated:") + for fname in self.files_generated: + rich.print(f" {fname}") + + def on_start(self) -> None: plugin_names = [plugin.__class__.__name__ for plugin in self.plugins] rich.print( @@ -107,6 +138,8 @@ def on_end(self, result: CodegenResult) -> None: self.output_dir.mkdir(parents=True, exist_ok=True) result.write(self.output_dir) + self.files_generated.extend(Path(cf.path) for cf in result.files) + rich.print( f"[green] Generated {len(result.files)} files in {self.output_dir}", ) @@ -114,7 +147,9 @@ def on_end(self, result: CodegenResult) -> None: @app.command(help="Generate code from a query") def codegen( - query: Path = typer.Argument(..., exists=True, dir_okay=False), + query: Optional[List[Path]] = typer.Argument( + default=None, exists=True, dir_okay=False + ), schema: str = typer.Option(..., help="Python path to the schema file"), app_dir: str = typer.Option( ".", @@ -143,12 +178,50 @@ def codegen( ), cli_plugin: Optional[str] = None, ) -> None: + if not query: + return + schema_symbol = load_schema(schema, app_dir) - console_plugin = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin + console_plugin_type = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin - plugins = _load_plugins(selected_plugins) - plugins.append(console_plugin(query, output_dir, plugins)) + plugins = _load_plugins(selected_plugins, query[0]) + console_plugin = console_plugin_type(query[0], output_dir, plugins) + + if not isinstance(console_plugin, ConsolePlugin): + warnings.warn( + "The ConsolePlugin should inherit from ``{__name__}.ConsolePlugin``.", + DeprecationWarning, + stacklevel=1, + ) - code_generator = QueryCodegen(schema_symbol, plugins=plugins) - code_generator.run(query.read_text()) + with contextlib.suppress(AttributeError): + # This method was not part of the original implementation of the ConsolePlugin + # and the ConsolePlugin is overridable on the CLI. It isn't guaranteed that + # custom ConsolePlugin implementations inherit from ``ConsolePlugin``, so if + # the implementor does not have this method, we will just keep going. + console_plugin.before_any_start() + + for q in query: + plugins = _load_plugins(selected_plugins, q) + console_plugin.query = q # update the query in the console plugin. + plugins.append(console_plugin) + + for p in plugins: + # Adjust the names of the output files for the buildin plugins if there + # are multiple files being generated. + if len(query) > 1: + if isinstance(p, PythonPlugin): + p.outfile_name = q.stem + ".py" + elif isinstance(p, TypeScriptPlugin): + p.outfile_name = q.stem + ".ts" + + code_generator = QueryCodegen(schema_symbol, plugins=plugins) + code_generator.run(q.read_text()) + + with contextlib.suppress(AttributeError): + # This method was not part of the original implementation of the ConsolePlugin + # and the ConsolePlugin is overridable on the CLI. It isn't guaranteed that + # custom ConsolePlugin implementations inherit from ``ConsolePlugin``, so if + # the implementor does not have this method, we will just keep going. + console_plugin.after_all_finished() diff --git a/strawberry/codegen/plugins/python.py b/strawberry/codegen/plugins/python.py index 0af167a5f1..5eb4ee5c38 100644 --- a/strawberry/codegen/plugins/python.py +++ b/strawberry/codegen/plugins/python.py @@ -25,6 +25,9 @@ ) +DEFAULT_OUTFILE_NAME = "types.py" + + @dataclass class PythonType: type: str @@ -47,6 +50,7 @@ class PythonPlugin(QueryCodegenPlugin): def __init__(self) -> None: self.imports: Dict[str, Set[str]] = defaultdict(set) + self.outfile_name: str = DEFAULT_OUTFILE_NAME def generate_code( self, types: List[GraphQLType], operation: GraphQLOperation @@ -56,7 +60,7 @@ def generate_code( code = imports + "\n\n" + "\n\n".join(printed_types) - return [CodegenFile("types.py", code.strip())] + return [CodegenFile(self.outfile_name, code.strip())] def _print_imports(self) -> str: imports = [ diff --git a/strawberry/codegen/plugins/typescript.py b/strawberry/codegen/plugins/typescript.py index 5d7e878f0b..03aedd3674 100644 --- a/strawberry/codegen/plugins/typescript.py +++ b/strawberry/codegen/plugins/typescript.py @@ -17,6 +17,9 @@ from strawberry.codegen.types import GraphQLField, GraphQLOperation, GraphQLType +DEFAULT_OUTFILE_NAME = "types.ts" + + class TypeScriptPlugin(QueryCodegenPlugin): SCALARS_TO_TS_TYPE = { "ID": "string", @@ -33,12 +36,15 @@ class TypeScriptPlugin(QueryCodegenPlugin): float: "number", } + def __init__(self) -> None: + self.outfile_name: str = DEFAULT_OUTFILE_NAME + def generate_code( self, types: List[GraphQLType], operation: GraphQLOperation ) -> List[CodegenFile]: printed_types = list(filter(None, (self._print_type(type) for type in types))) - return [CodegenFile("types.ts", "\n\n".join(printed_types))] + return [CodegenFile(self.outfile_name, "\n\n".join(printed_types))] def _get_type_name(self, type_: GraphQLType) -> str: if isinstance(type_, GraphQLOptional): diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 7b9df2b756..e66757cfcd 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -130,6 +130,12 @@ class HasSelectionSet(Protocol): class QueryCodegenPlugin: + #: That path that holds the query being processed. + #: This gets passed to __init__ if there is a parameter named ``query`` or + #: ``query_file``. Otherwise, it gets set after ``__init__`` for backward + #: compatibility. + query: Path + def on_start(self) -> None: ... diff --git a/tests/cli/test_codegen.py b/tests/cli/test_codegen.py index f63fef07cd..37830d9f65 100644 --- a/tests/cli/test_codegen.py +++ b/tests/cli/test_codegen.py @@ -17,6 +17,20 @@ def on_end(self, result: CodegenResult): return super().on_end(result) +class ConsoleTestPluginWithoutTypicalBase(QueryCodegenPlugin): + def __init__( + self, query: Path, output_dir: Path, plugins: List[QueryCodegenPlugin] + ): + # Delegate to a `ConsoleTestPlugin` to mimic the behavior of some custom + # logic. + self._delegate = ConsoleTestPlugin(query, output_dir, plugins) + + def on_end(self, result: CodegenResult): + result.files[0].path = "renamed.py" + + return self._delegate.on_end(result) + + class QueryCodegenTestPlugin(QueryCodegenPlugin): def generate_code( self, types: List[GraphQLType], operation: GraphQLOperation @@ -29,6 +43,21 @@ def generate_code( ] +class QueryCodegenTestPluginWithInjectableQuery(QueryCodegenPlugin): + def __init__(self, query: Path, query_file: Path) -> None: + assert query == query_file + + def generate_code( + self, types: List[GraphQLType], operation: GraphQLOperation + ) -> List[CodegenFile]: + return [ + CodegenFile( + path="test.py", + content=f"# This is a test file for {operation.name}", + ) + ] + + class EmptyPlugin(QueryCodegenPlugin): def generate_code( self, types: List[GraphQLType], operation: GraphQLOperation @@ -56,6 +85,21 @@ def query_file_path(tmp_path: Path) -> Path: return output_path +@pytest.fixture +def query_file_path2(tmp_path: Path) -> Path: + output_path = tmp_path / "query2.graphql" + output_path.write_text( + """ + query GetUser { + user { + name + } + } + """ + ) + return output_path + + def test_codegen(cli_runner: CliRunner, query_file_path: Path, tmp_path: Path): selector = "tests.fixtures.sample_package.sample_module:schema" result = cli_runner.invoke( @@ -80,6 +124,87 @@ def test_codegen(cli_runner: CliRunner, query_file_path: Path, tmp_path: Path): assert code_path.read_text() == "# This is a test file for GetUser" +def test_codegen_injection_parameters( + cli_runner: CliRunner, query_file_path: Path, tmp_path: Path +): + selector = "tests.fixtures.sample_package.sample_module:schema" + result = cli_runner.invoke( + app, + [ + "codegen", + "-p", + "tests.cli.test_codegen:QueryCodegenTestPluginWithInjectableQuery", + "-o", + str(tmp_path), + "--schema", + selector, + str(query_file_path), + ], + ) + + assert result.exit_code == 0 + + code_path = tmp_path / "test.py" + + assert code_path.exists() + assert code_path.read_text() == "# This is a test file for GetUser" + + +def test_codegen_multiple_files( + cli_runner: CliRunner, query_file_path: Path, query_file_path2: Path, tmp_path: Path +): + expected_paths = [ + tmp_path / "query.py", + tmp_path / "query2.py", + tmp_path / "query.ts", + tmp_path / "query2.ts", + ] + for path in expected_paths: + assert not path.exists() + + selector = "tests.fixtures.sample_package.sample_module:schema" + result = cli_runner.invoke( + app, + [ + "codegen", + "-p", + "python", + "-p", + "typescript", + "-o", + str(tmp_path), + "--schema", + selector, + str(query_file_path), + str(query_file_path2), + ], + ) + + assert result.exit_code == 0 + + for path in expected_paths: + assert path.exists() + assert " GetUserResult" in path.read_text() + + +def test_codegen_pass_no_query(cli_runner: CliRunner, tmp_path: Path): + selector = "tests.fixtures.sample_package.sample_module:schema" + result = cli_runner.invoke( + app, + [ + "codegen", + "-p", + "tests.cli.test_codegen:EmptyPlugin", + "-o", + str(tmp_path), + "--schema", + selector, + ], + ) + + assert result.exit_code == 0 + + def test_codegen_passing_plugin_symbol( cli_runner: CliRunner, query_file_path: Path, tmp_path: Path ): @@ -226,3 +351,31 @@ def test_can_use_custom_cli_plugin( assert code_path.exists() assert "class GetUserResult" in code_path.read_text() + + +def test_can_use_custom_cli_plugin_with_regular_codegen_plugin_inheritance( + cli_runner: CliRunner, query_file_path: Path, tmp_path: Path +): + selector = "tests.fixtures.sample_package.sample_module:schema" + result = cli_runner.invoke( + app, + [ + "codegen", + "--cli-plugin", + "tests.cli.test_codegen:ConsoleTestPluginWithoutTypicalBase", + "-p", + "python", + "--schema", + selector, + "-o", + str(tmp_path), + str(query_file_path), + ], + ) + + assert result.exit_code == 0 + + code_path = tmp_path / "renamed.py" + + assert code_path.exists() + assert "class GetUserResult" in code_path.read_text()