diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..8f4c435209 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,22 @@ +Release type: minor + +`strawberry codegen` can now operate on multiple input query files. +The previous behavior of naming the file `types.js` and `types.py` +for the builtin `typescript` and `python` plugins respectively is +preserved, but only if a single query file is passed. When more +than one query file is passed, the code generator will now use +the stem of the query file's name to construct the name of the +output files. e.g. `my_query.graphql` -> `my_query.js` or +`my_query.py`. Creators of custom plugins are responsible +for controlling the name of the output file themselves. To +accomodate this, if the `__init__` method of a `QueryCodegenPlugin` +has a parameter named `query` or `query_file`, the `pathlib.Path` +to the query file will be passed to the plugin's `__init__` +method. + +Finally, the `ConsolePlugin` has also recieved two new lifecycle +methods. Unlike other `QueryCodegenPlugin`, the same instance of +the `ConsolePlugin` is used for each query file processed. This +allows it to keep state around how many total files were processed. +The `ConsolePlugin` recieved two new lifecycle hooks: `before_any_start` +and `after_all_finished` that get called at the appropriate times. diff --git a/docs/codegen/query-codegen.md b/docs/codegen/query-codegen.md index a9abef13e0..03974ca0db 100644 --- a/docs/codegen/query-codegen.md +++ b/docs/codegen/query-codegen.md @@ -73,7 +73,7 @@ With the following command: strawberry codegen --schema schema --output-dir ./output -p python query.graphql ``` -We'll get the following output inside `output/types.py`: +We'll get the following output inside `output/query.py`: ```python class MyQueryResultUserPost: @@ -119,6 +119,13 @@ from strawberry.codegen.types import GraphQLType, GraphQLOperation class QueryCodegenPlugin: + def __init__(self, query: Path) -> None: + """Initialize the plugin. + + The singular argument is the path to the file that is being processed by this plugin. + """ + self.query = query + def on_start(self) -> None: ... @@ -137,3 +144,36 @@ class QueryCodegenPlugin: - `generated_code` is called when the codegen starts and it receives the types and the operation. You cans use this to generate code for each type and operation. + +### Console plugin + +There is also a plugin that helps to orchestrate the codegen process and notify the +user about what the current codegen process is doing. + +The interface for the ConsolePlugin looks like: + +```python +class ConsolePlugin: + def __init__(self, output_dir: Path): + """Initialize the plugin and tell it where the output should be written.""" + ... + + def before_any_start(self) -> None: + """This method is called before any plugins have been invoked or any queries have been processed.""" + ... + + def after_all_finished(self) -> None: + """This method is called after the full code generation is complete. + + It can be used to report on all the things that have happened during the codegen. + """ + ... + + def on_start(self, plugins: Iterable[QueryCodegenPlugin], query: Path) -> None: + """This method is called before any of the individual plugins have been started.""" + ... + + def on_end(self, result: CodegenResult) -> None: + """This method typically persists the results from a single query to the output directory.""" + ... +``` diff --git a/strawberry/cli/commands/codegen.py b/strawberry/cli/commands/codegen.py index 5e434c0071..6e398a54f0 100644 --- a/strawberry/cli/commands/codegen.py +++ b/strawberry/cli/commands/codegen.py @@ -1,25 +1,23 @@ from __future__ import annotations +import functools import importlib import inspect from pathlib import Path # noqa: TCH003 -from typing import TYPE_CHECKING, List, Optional, Type +from typing import List, Optional, Type import rich import typer from strawberry.cli.app import app from strawberry.cli.utils import load_schema -from strawberry.codegen import QueryCodegen, QueryCodegenPlugin - -if TYPE_CHECKING: - from strawberry.codegen import CodegenResult +from strawberry.codegen import ConsolePlugin, QueryCodegen, QueryCodegenPlugin def _is_codegen_plugin(obj: object) -> bool: return ( inspect.isclass(obj) - and issubclass(obj, QueryCodegenPlugin) + and issubclass(obj, (QueryCodegenPlugin, ConsolePlugin)) and obj is not QueryCodegenPlugin ) @@ -62,6 +60,7 @@ def _import_plugin(plugin: str) -> Optional[Type[QueryCodegenPlugin]]: return None +@functools.lru_cache def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]: # try to import plugin_name from current folder # then try to import from strawberry.codegen.plugins @@ -78,43 +77,21 @@ def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]: return plugin -def _load_plugins(plugins: List[str]) -> List[QueryCodegenPlugin]: - return [_load_plugin(plugin)() for plugin in plugins] - - -class ConsolePlugin(QueryCodegenPlugin): - def __init__( - self, query: Path, output_dir: Path, plugins: List[QueryCodegenPlugin] - ): - self.query = query - self.output_dir = output_dir - self.plugins = plugins - - def on_start(self) -> None: - rich.print( - "[bold yellow]The codegen is experimental. Please submit any bug at " - "https://github.com/strawberry-graphql/strawberry\n", - ) - - plugin_names = [plugin.__class__.__name__ for plugin in self.plugins] - - rich.print( - f"[green]Generating code for {self.query} using " - f"{', '.join(plugin_names)} plugin(s)", - ) - - def on_end(self, result: CodegenResult) -> None: - self.output_dir.mkdir(parents=True, exist_ok=True) - result.write(self.output_dir) +def _load_plugins(plugin_ids: List[str], query: Path) -> List[QueryCodegenPlugin]: + plugins = [] + for ptype_id in plugin_ids: + ptype = _load_plugin(ptype_id) + plugin = ptype(query) + plugins.append(plugin) - rich.print( - f"[green] Generated {len(result.files)} files in {self.output_dir}", - ) + return plugins @app.command(help="Generate code from a query") def codegen( - query: Path = typer.Argument(..., exists=True, dir_okay=False), + query: Optional[List[Path]] = typer.Argument( + default=None, exists=True, dir_okay=False + ), schema: str = typer.Option(..., help="Python path to the schema file"), app_dir: str = typer.Option( ".", @@ -143,12 +120,22 @@ def codegen( ), cli_plugin: Optional[str] = None, ) -> None: + if not query: + return + schema_symbol = load_schema(schema, app_dir) - console_plugin = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin + console_plugin_type = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin + console_plugin = console_plugin_type(output_dir) + console_plugin.before_any_start() + + for q in query: + plugins = _load_plugins(selected_plugins, q) + console_plugin.query = q # update the query in the console plugin. - plugins = _load_plugins(selected_plugins) - plugins.append(console_plugin(query, output_dir, plugins)) + code_generator = QueryCodegen( + schema_symbol, plugins=plugins, console_plugin=console_plugin + ) + code_generator.run(q.read_text()) - code_generator = QueryCodegen(schema_symbol, plugins=plugins) - code_generator.run(query.read_text()) + console_plugin.after_all_finished() diff --git a/strawberry/codegen/__init__.py b/strawberry/codegen/__init__.py index c21594351c..9609d0d674 100644 --- a/strawberry/codegen/__init__.py +++ b/strawberry/codegen/__init__.py @@ -1,3 +1,15 @@ -from .query_codegen import CodegenFile, CodegenResult, QueryCodegen, QueryCodegenPlugin +from .query_codegen import ( + CodegenFile, + CodegenResult, + ConsolePlugin, + QueryCodegen, + QueryCodegenPlugin, +) -__all__ = ["QueryCodegen", "QueryCodegenPlugin", "CodegenFile", "CodegenResult"] +__all__ = [ + "CodegenFile", + "CodegenResult", + "ConsolePlugin", + "QueryCodegen", + "QueryCodegenPlugin", +] diff --git a/strawberry/codegen/plugins/python.py b/strawberry/codegen/plugins/python.py index 41b17d4543..03f97aceb8 100644 --- a/strawberry/codegen/plugins/python.py +++ b/strawberry/codegen/plugins/python.py @@ -18,6 +18,8 @@ ) if TYPE_CHECKING: + from pathlib import Path + from strawberry.codegen.types import ( GraphQLArgumentValue, GraphQLField, @@ -46,8 +48,10 @@ class PythonPlugin(QueryCodegenPlugin): "Decimal": PythonType("Decimal", "decimal"), } - def __init__(self) -> None: + def __init__(self, query: Path) -> None: self.imports: Dict[str, Set[str]] = defaultdict(set) + self.outfile_name: str = query.with_suffix(".py").name + self.query = query def generate_code( self, types: List[GraphQLType], operation: GraphQLOperation @@ -57,7 +61,7 @@ def generate_code( code = imports + "\n\n" + "\n\n".join(printed_types) - return [CodegenFile("types.py", code.strip())] + return [CodegenFile(self.outfile_name, code.strip())] def _print_imports(self) -> str: imports = [ diff --git a/strawberry/codegen/plugins/typescript.py b/strawberry/codegen/plugins/typescript.py index 5d7e878f0b..ede9bd859d 100644 --- a/strawberry/codegen/plugins/typescript.py +++ b/strawberry/codegen/plugins/typescript.py @@ -14,6 +14,8 @@ ) if TYPE_CHECKING: + from pathlib import Path + from strawberry.codegen.types import GraphQLField, GraphQLOperation, GraphQLType @@ -33,12 +35,16 @@ class TypeScriptPlugin(QueryCodegenPlugin): float: "number", } + def __init__(self, query: Path) -> None: + self.outfile_name: str = query.with_suffix(".ts").name + self.query = query + def generate_code( self, types: List[GraphQLType], operation: GraphQLOperation ) -> List[CodegenFile]: printed_types = list(filter(None, (self._print_type(type) for type in types))) - return [CodegenFile("types.ts", "\n\n".join(printed_types))] + return [CodegenFile(self.outfile_name, "\n\n".join(printed_types))] def _get_type_name(self, type_: GraphQLType) -> str: if isinstance(type_, GraphQLOptional): diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index a3628324d2..f8724d6107 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -3,6 +3,7 @@ from dataclasses import MISSING, dataclass from enum import Enum from functools import cmp_to_key, partial +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -19,6 +20,7 @@ ) from typing_extensions import Literal, Protocol +import rich from graphql import ( BooleanValueNode, EnumValueNode, @@ -88,8 +90,6 @@ ) if TYPE_CHECKING: - from pathlib import Path - from graphql import ( ArgumentNode, DirectiveNode, @@ -131,6 +131,14 @@ class HasSelectionSet(Protocol): class QueryCodegenPlugin: + def __init__(self, query: Path) -> None: + """Initialize the plugin. + + The singular argument is the path to the file that is being processed + by this plugin. + """ + self.query = query + def on_start(self) -> None: ... @@ -143,6 +151,41 @@ def generate_code( return [] +class ConsolePlugin: + def __init__(self, output_dir: Path): + self.output_dir = output_dir + self.files_generated: List[Path] = [] + + def before_any_start(self) -> None: + rich.print( + "[bold yellow]The codegen is experimental. Please submit any bug at " + "https://github.com/strawberry-graphql/strawberry\n", + ) + + def after_all_finished(self) -> None: + rich.print("[green]Generated:") + for fname in self.files_generated: + rich.print(f" {fname}") + + def on_start(self, plugins: Iterable[QueryCodegenPlugin], query: Path) -> None: + plugin_names = [plugin.__class__.__name__ for plugin in plugins] + + rich.print( + f"[green]Generating code for {query} using " + f"{', '.join(plugin_names)} plugin(s)", + ) + + def on_end(self, result: CodegenResult) -> None: + self.output_dir.mkdir(parents=True, exist_ok=True) + result.write(self.output_dir) + + self.files_generated.extend(Path(cf.path) for cf in result.files) + + rich.print( + f"[green] Generated {len(result.files)} files in {self.output_dir}", + ) + + def _get_deps(t: GraphQLType) -> Iterable[GraphQLType]: """Get all the types that `t` depends on. @@ -195,8 +238,13 @@ def _py_to_graphql_value(obj: Any) -> GraphQLArgumentValue: class QueryCodegenPluginManager: - def __init__(self, plugins: List[QueryCodegenPlugin]) -> None: + def __init__( + self, + plugins: List[QueryCodegenPlugin], + console_plugin: Optional[ConsolePlugin] = None, + ) -> None: self.plugins = plugins + self.console_plugin = console_plugin def _sort_types(self, types: List[GraphQLType]) -> List[GraphQLType]: """Sort the types. @@ -234,6 +282,12 @@ def generate_code( return result def on_start(self) -> None: + if self.console_plugin and self.plugins: + # We need the query that we're processing + # just pick it off the first plugin + query = self.plugins[0].query + self.console_plugin.on_start(self.plugins, query) + for plugin in self.plugins: plugin.on_start() @@ -241,11 +295,19 @@ def on_end(self, result: CodegenResult) -> None: for plugin in self.plugins: plugin.on_end(result) + if self.console_plugin: + self.console_plugin.on_end(result) + class QueryCodegen: - def __init__(self, schema: Schema, plugins: List[QueryCodegenPlugin]): + def __init__( + self, + schema: Schema, + plugins: List[QueryCodegenPlugin], + console_plugin: Optional[ConsolePlugin] = None, + ): self.schema = schema - self.plugin_manager = QueryCodegenPluginManager(plugins) + self.plugin_manager = QueryCodegenPluginManager(plugins, console_plugin) self.types: List[GraphQLType] = [] def run(self, query: str) -> CodegenResult: diff --git a/tests/cli/test_codegen.py b/tests/cli/test_codegen.py index b5b631ace1..991eaaf488 100644 --- a/tests/cli/test_codegen.py +++ b/tests/cli/test_codegen.py @@ -56,6 +56,21 @@ def query_file_path(tmp_path: Path) -> Path: return output_path +@pytest.fixture +def query_file_path2(tmp_path: Path) -> Path: + output_path = tmp_path / "query2.graphql" + output_path.write_text( + """ + query GetUser { + user { + name + } + } + """ + ) + return output_path + + def test_codegen( cli_app: Typer, cli_runner: CliRunner, query_file_path: Path, tmp_path: Path ): @@ -82,6 +97,65 @@ def test_codegen( assert code_path.read_text() == "# This is a test file for GetUser" +def test_codegen_multiple_files( + cli_app: Typer, + cli_runner: CliRunner, + query_file_path: Path, + query_file_path2: Path, + tmp_path: Path, +): + expected_paths = [ + tmp_path / "query.py", + tmp_path / "query2.py", + tmp_path / "query.ts", + tmp_path / "query2.ts", + ] + for path in expected_paths: + assert not path.exists() + + selector = "tests.fixtures.sample_package.sample_module:schema" + result = cli_runner.invoke( + cli_app, + [ + "codegen", + "-p", + "python", + "-p", + "typescript", + "-o", + str(tmp_path), + "--schema", + selector, + str(query_file_path), + str(query_file_path2), + ], + ) + + assert result.exit_code == 0 + + for path in expected_paths: + assert path.exists() + assert " GetUserResult" in path.read_text() + + +def test_codegen_pass_no_query(cli_app: Typer, cli_runner: CliRunner, tmp_path: Path): + selector = "tests.fixtures.sample_package.sample_module:schema" + result = cli_runner.invoke( + cli_app, + [ + "codegen", + "-p", + "tests.cli.test_codegen:EmptyPlugin", + "-o", + str(tmp_path), + "--schema", + selector, + ], + ) + + assert result.exit_code == 0 + + def test_codegen_passing_plugin_symbol( cli_app: Typer, cli_runner: CliRunner, query_file_path: Path, tmp_path: Path ): @@ -197,7 +271,7 @@ def test_codegen_finds_our_plugins( assert result.exit_code == 0 - code_path = tmp_path / "types.py" + code_path = tmp_path / query_file_path.with_suffix(".py").name assert code_path.exists() assert "class GetUserResult" in code_path.read_text() diff --git a/tests/codegen/test_print_operation.py b/tests/codegen/test_print_operation.py index fd99a3997e..daad7b406c 100644 --- a/tests/codegen/test_print_operation.py +++ b/tests/codegen/test_print_operation.py @@ -14,7 +14,7 @@ def test_codegen( query: Path, schema, ): - generator = QueryCodegen(schema, plugins=[PrintOperationPlugin()]) + generator = QueryCodegen(schema, plugins=[PrintOperationPlugin(query)]) query_content = query.read_text() result = generator.run(query_content) diff --git a/tests/codegen/test_query_codegen.py b/tests/codegen/test_query_codegen.py index a016205a06..78a3b71259 100644 --- a/tests/codegen/test_query_codegen.py +++ b/tests/codegen/test_query_codegen.py @@ -40,7 +40,7 @@ def test_codegen( snapshot: Snapshot, schema, ): - generator = QueryCodegen(schema, plugins=[plugin_class()]) + generator = QueryCodegen(schema, plugins=[plugin_class(query)]) result = generator.run(query.read_text()) @@ -50,22 +50,37 @@ def test_codegen( snapshot.assert_match(code, f"{query.with_suffix('').stem}.{extension}") -def test_codegen_fails_if_no_operation_name(schema): - generator = QueryCodegen(schema, plugins=[PythonPlugin()]) +def test_codegen_fails_if_no_operation_name(schema, tmp_path): + query = tmp_path / "query.graphql" + data = "query { hello }" + with query.open("w") as f: + f.write(data) + + generator = QueryCodegen(schema, plugins=[PythonPlugin(query)]) with pytest.raises(NoOperationNameProvidedError): - generator.run("query { hello }") + generator.run(data) + +def test_codegen_fails_if_no_operation(schema, tmp_path): + query = tmp_path / "query.graphql" + data = "type X { hello: String }" + with query.open("w") as f: + f.write(data) -def test_codegen_fails_if_no_operation(schema): - generator = QueryCodegen(schema, plugins=[PythonPlugin()]) + generator = QueryCodegen(schema, plugins=[PythonPlugin(query)]) with pytest.raises(NoOperationProvidedError): - generator.run("type X { hello: String }") + generator.run(data) + +def test_fails_with_multiple_operations(schema, tmp_path): + query = tmp_path / "query.graphql" + data = "query { hello } query { world }" + with query.open("w") as f: + f.write(data) -def test_fails_with_multiple_operations(schema): - generator = QueryCodegen(schema, plugins=[PythonPlugin()]) + generator = QueryCodegen(schema, plugins=[PythonPlugin(query)]) with pytest.raises(MultipleOperationsProvidedError): - generator.run("query { hello } query { world }") + generator.run(data)