Skip to content

Commit

Permalink
Add ability for codegen to handle multiple input query files.
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt Gilson committed Jul 3, 2023
1 parent ccc5b52 commit 1a41313
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 12 deletions.
93 changes: 83 additions & 10 deletions strawberry/cli/commands/codegen.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import contextlib
import functools
import importlib
import inspect
from pathlib import Path # noqa: TCH003
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Type

import rich
Expand All @@ -11,6 +14,8 @@
from strawberry.cli.app import app
from strawberry.cli.utils import load_schema
from strawberry.codegen import QueryCodegen, QueryCodegenPlugin
from strawberry.codegen.plugins.python import PythonPlugin
from strawberry.codegen.plugins.typescript import TypeScriptPlugin

if TYPE_CHECKING:
from strawberry.codegen import CodegenResult
Expand Down Expand Up @@ -62,6 +67,7 @@ def _import_plugin(plugin: str) -> Optional[Type[QueryCodegenPlugin]]:
return None


@functools.lru_cache()
def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]:
# try to import plugin_name from current folder
# then try to import from strawberry.codegen.plugins
Expand All @@ -78,8 +84,26 @@ def _load_plugin(plugin_path: str) -> Type[QueryCodegenPlugin]:
return plugin


def _load_plugins(plugins: List[str]) -> List[QueryCodegenPlugin]:
return [_load_plugin(plugin)() for plugin in plugins]
def _load_plugins(plugin_ids: List[str], query: Path) -> List[QueryCodegenPlugin]:
plugins = []
for ptype_id in plugin_ids:
ptype = _load_plugin(ptype_id)
# inject arguments into the plugin's `__init__` based on
# the names in the signature (e.g., similar to flake8 plugins).
sig = inspect.signature(ptype)
kwargs = {}
if "query" in sig.parameters:
kwargs["query"] = query
if "query_file" in sig.parameters:
kwargs["query_file"] = query

plugin = ptype(**kwargs)
if not hasattr(plugin, "query"):
plugin.query = query

plugins.append(plugin)

return plugins


class ConsolePlugin(QueryCodegenPlugin):
Expand All @@ -89,13 +113,20 @@ def __init__(
self.query = query
self.output_dir = output_dir
self.plugins = plugins
self.files_generated: List[Path] = []

def on_start(self) -> None:
def before_any_start(self) -> None:
rich.print(
"[bold yellow]The codegen is experimental. Please submit any bug at "
"https://github.com/strawberry-graphql/strawberry\n",
)

def after_all_finished(self) -> None:
rich.print("[green]Generated:")
for fname in self.files_generated:
rich.print(f" {fname}")

def on_start(self) -> None:
plugin_names = [plugin.__class__.__name__ for plugin in self.plugins]

rich.print(
Expand All @@ -107,14 +138,18 @@ def on_end(self, result: CodegenResult) -> None:
self.output_dir.mkdir(parents=True, exist_ok=True)
result.write(self.output_dir)

self.files_generated.extend(Path(cf.path) for cf in result.files)

rich.print(
f"[green] Generated {len(result.files)} files in {self.output_dir}",
)


@app.command(help="Generate code from a query")
def codegen(
query: Path = typer.Argument(..., exists=True, dir_okay=False),
query: Optional[List[Path]] = typer.Argument(
default=None, exists=True, dir_okay=False
),
schema: str = typer.Option(..., help="Python path to the schema file"),
app_dir: str = typer.Option(
".",
Expand Down Expand Up @@ -143,12 +178,50 @@ def codegen(
),
cli_plugin: Optional[str] = None,
) -> None:
if not query:
return

schema_symbol = load_schema(schema, app_dir)

console_plugin = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin
console_plugin_type = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin

plugins = _load_plugins(selected_plugins)
plugins.append(console_plugin(query, output_dir, plugins))
plugins = _load_plugins(selected_plugins, query[0])
console_plugin = console_plugin_type(query[0], output_dir, plugins)

if not isinstance(console_plugin, ConsolePlugin):
warnings.warn(
"The ConsolePlugin should inherit from ``{__name__}.ConsolePlugin``.",
DeprecationWarning,
stacklevel=1,
)

code_generator = QueryCodegen(schema_symbol, plugins=plugins)
code_generator.run(query.read_text())
with contextlib.suppress(AttributeError):
# This method was not part of the original implementation of the ConsolePlugin
# and the ConsolePlugin is overridable on the CLI. It isn't guaranteed that
# custom ConsolePlugin implementations inherit from ``ConsolePlugin``, so if
# the implementor does not have this method, we will just keep going.
console_plugin.before_any_start()

for q in query:
plugins = _load_plugins(selected_plugins, q)
console_plugin.query = q # update the query in the console plugin.
plugins.append(console_plugin)

for p in plugins:
# Adjust the names of the output files for the buildin plugins if there
# are multiple files being generated.
if len(query) > 1:
if isinstance(p, PythonPlugin):
p.outfile_name = q.stem + ".py"
elif isinstance(p, TypeScriptPlugin):
p.outfile_name = q.stem + ".ts"

code_generator = QueryCodegen(schema_symbol, plugins=plugins)
code_generator.run(q.read_text())

with contextlib.suppress(AttributeError):
# This method was not part of the original implementation of the ConsolePlugin
# and the ConsolePlugin is overridable on the CLI. It isn't guaranteed that
# custom ConsolePlugin implementations inherit from ``ConsolePlugin``, so if
# the implementor does not have this method, we will just keep going.
console_plugin.after_all_finished()
6 changes: 5 additions & 1 deletion strawberry/codegen/plugins/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
)


DEFAULT_OUTFILE_NAME = "types.py"


@dataclass
class PythonType:
type: str
Expand All @@ -47,6 +50,7 @@ class PythonPlugin(QueryCodegenPlugin):

def __init__(self) -> None:
self.imports: Dict[str, Set[str]] = defaultdict(set)
self.outfile_name: str = DEFAULT_OUTFILE_NAME

def generate_code(
self, types: List[GraphQLType], operation: GraphQLOperation
Expand All @@ -56,7 +60,7 @@ def generate_code(

code = imports + "\n\n" + "\n\n".join(printed_types)

return [CodegenFile("types.py", code.strip())]
return [CodegenFile(self.outfile_name, code.strip())]

def _print_imports(self) -> str:
imports = [
Expand Down
8 changes: 7 additions & 1 deletion strawberry/codegen/plugins/typescript.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from strawberry.codegen.types import GraphQLField, GraphQLOperation, GraphQLType


DEFAULT_OUTFILE_NAME = "types.ts"


class TypeScriptPlugin(QueryCodegenPlugin):
SCALARS_TO_TS_TYPE = {
"ID": "string",
Expand All @@ -33,12 +36,15 @@ class TypeScriptPlugin(QueryCodegenPlugin):
float: "number",
}

def __init__(self) -> None:
self.outfile_name: str = DEFAULT_OUTFILE_NAME

def generate_code(
self, types: List[GraphQLType], operation: GraphQLOperation
) -> List[CodegenFile]:
printed_types = list(filter(None, (self._print_type(type) for type in types)))

return [CodegenFile("types.ts", "\n\n".join(printed_types))]
return [CodegenFile(self.outfile_name, "\n\n".join(printed_types))]

def _get_type_name(self, type_: GraphQLType) -> str:
if isinstance(type_, GraphQLOptional):
Expand Down
6 changes: 6 additions & 0 deletions strawberry/codegen/query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ class HasSelectionSet(Protocol):


class QueryCodegenPlugin:
#: That path that holds the query being processed.
#: This gets passed to __init__ if there is a parameter named ``query`` or
#: ``query_file``. Otherwise, it gets set after ``__init__`` for backward
#: compatibility.
query: Path

def on_start(self) -> None:
...

Expand Down
Loading

0 comments on commit 1a41313

Please sign in to comment.