From 2a0e96a4c53f75338d6829f68a603111bbbfc9e8 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 19 Mar 2023 07:20:08 -0400 Subject: [PATCH] feat(protoc): forward protoc arguments to protoc generator --- protoletariat/__main__.py | 23 +++---- protoletariat/fdsetgen.py | 21 +++--- protoletariat/tests/conftest.py | 116 ++++++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 26 deletions(-) diff --git a/protoletariat/__main__.py b/protoletariat/__main__.py index c3a8d6e4..1127038f 100644 --- a/protoletariat/__main__.py +++ b/protoletariat/__main__.py @@ -5,7 +5,7 @@ import os import sys from pathlib import Path -from typing import IO +from typing import IO, Iterable import click @@ -100,7 +100,10 @@ def main( ) -@main.command(help="Use protoc to generate the FileDescriptorSet blob") +@main.command( + context_settings=dict(ignore_unknown_options=True), + help="Use protoc to generate the FileDescriptorSet blob", +) @click.option( "--protoc-path", envvar="PROTOC_PATH", @@ -124,28 +127,18 @@ def main( ), help="Protobuf file search path(s). Accepts multiple values.", ) -@click.argument( - "proto_files", - nargs=-1, - required=True, - type=click.Path( - file_okay=True, - dir_okay=False, - exists=True, - path_type=Path, - ), -) +@click.argument("protoc_args", nargs=-1, type=click.UNPROCESSED) @click.pass_context def protoc( ctx: click.Context, protoc_path: str, proto_paths: list[Path], - proto_files: list[Path], + protoc_args: Iterable[str], ) -> None: Protoc( protoc_path=os.fsdecode(protoc_path), - proto_files=[Path(os.fsdecode(proto_file)) for proto_file in proto_files], proto_paths=[Path(os.fsdecode(proto_path)) for proto_path in proto_paths], + protoc_args=list(protoc_args), ).fix_imports(**ctx.obj) diff --git a/protoletariat/fdsetgen.py b/protoletariat/fdsetgen.py index fc076dcf..7748a113 100644 --- a/protoletariat/fdsetgen.py +++ b/protoletariat/fdsetgen.py @@ -143,25 +143,24 @@ def __init__( self, *, protoc_path: str, - proto_files: Iterable[Path], proto_paths: Iterable[Path], + protoc_args: Iterable[str], ) -> None: self.protoc_path = protoc_path - self.proto_files = proto_files self.proto_paths = proto_paths + self.protoc_args = protoc_args def generate_file_descriptor_set_bytes(self) -> bytes: with tempfile.NamedTemporaryFile(delete=False) as f: filename = Path(f.name) - subprocess.check_output( - [ - *shlex.split(self.protoc_path), - "--include_imports", - f"--descriptor_set_out={filename}", - *map("--proto_path={}".format, self.proto_paths), - *map(str, self.proto_files), - ] - ) + args = [ + *shlex.split(self.protoc_path), + "--include_imports", + f"--descriptor_set_out={filename}", + *map("--proto_path={}".format, self.proto_paths), + *self.protoc_args, + ] + subprocess.check_output(args) try: return filename.read_bytes() diff --git a/protoletariat/tests/conftest.py b/protoletariat/tests/conftest.py index 6fc24ba7..b64ff0f0 100644 --- a/protoletariat/tests/conftest.py +++ b/protoletariat/tests/conftest.py @@ -208,6 +208,83 @@ class GrpcIoToolsFixture(ProtocFixture): protoc_exe = sys.executable, "-m", "grpc_tools.protoc" +class RawProtocFixture(ProtoletariatFixture): + protoc_exe = ("protoc",) + + def __init__( + self, + *, + base_dir: Path, + package: str, + proto_texts: Iterable[ProtoFile], + monkeypatch: pytest.MonkeyPatch, + grpc: bool = False, + mypy: bool = False, + mypy_grpc: bool = False, + ) -> None: + super().__init__( + base_dir=base_dir, + package=package, + proto_texts=proto_texts, + monkeypatch=monkeypatch, + ) + self.grpc = grpc + self.mypy = mypy + self.mypy_grpc = mypy_grpc + + def do_generate(self, cli: CliRunner, *, args: Iterable[str] = ()) -> Result: + with tempfile.NamedTemporaryFile(delete=False) as f: + filename = f.name + + protoc_args = [ + "protoc", + "--include_imports", + f"--descriptor_set_out={filename}", + "--proto_path", + str(self.base_dir), + "--python_out", + str(self.package_dir), + *(str(fn) for fn, _ in self.proto_texts), + ] + + if self.grpc: + # XXX: why isn't this found? PATH is set properly + grpc_python_plugin = shutil.which("grpc_python_plugin") + protoc_args.extend( + ( + f"--plugin=protoc-gen-grpc_python={grpc_python_plugin}", + "--grpc_python_out", + str(self.package_dir), + ) + ) + if self.mypy: + protoc_args.extend(("--mypy_out", str(self.package_dir))) + if self.mypy_grpc: + protoc_args.extend(("--mypy_grpc_out", str(self.package_dir))) + + subprocess.check_call(protoc_args) + + try: + return cli.invoke( + main, + [ + "--python-out", + str(self.package_dir), + *args, + "protoc", + "--protoc-path", + shlex.join(self.protoc_exe), + "--proto-path", + str(self.base_dir), + f"--descriptor_set_in={filename}", + *(str(filename) for filename, _ in self.proto_texts), + ], + catch_exceptions=False, + ) + finally: + os.unlink(filename) + + class RawFixture(ProtoletariatFixture): def __init__( self, @@ -337,6 +414,10 @@ def basic_cli_texts(request: SubRequest) -> list[ProtoFile]: partial(RawFixture, package="basic_cli"), id="basic_cli_raw", ), + pytest.param( + partial(RawProtocFixture, package="basic_cli"), + id="basic_cli_raw_protoc", + ), ] ) def basic_cli( @@ -420,6 +501,10 @@ def thing_service_texts(request: SubRequest) -> list[ProtoFile]: partial(RawFixture, package="thing_service", grpc=True), id="thing_service_raw", ), + pytest.param( + partial(RawProtocFixture, package="thing_service", grpc=True), + id="thing_service_raw_protoc", + ), ] ) def thing_service( @@ -475,6 +560,9 @@ def nested_texts() -> list[ProtoFile]: partial(GrpcIoToolsFixture, package="nested"), id="nested_grpc_io_tools" ), pytest.param(partial(RawFixture, package="nested"), id="nested_raw"), + pytest.param( + partial(RawProtocFixture, package="nested"), id="nested_raw_protoc" + ), ] ) def nested( @@ -553,6 +641,16 @@ def no_imports_service_texts(request: SubRequest) -> list[ProtoFile]: ), id="no_imports_service_raw", ), + pytest.param( + partial( + RawProtocFixture, + package="no_imports_service", + grpc=True, + mypy=True, + mypy_grpc=True, + ), + id="no_imports_service_raw_protoc", + ), ] ) def no_imports_service( @@ -652,6 +750,16 @@ def imports_service_texts(request: SubRequest) -> list[ProtoFile]: ), id="imports_service_raw", ), + pytest.param( + partial( + RawProtocFixture, + package="imports_service", + grpc=True, + mypy=True, + mypy_grpc=True, + ), + id="imports_service_raw_protoc", + ), ] ) def grpc_imports( @@ -708,6 +816,10 @@ def long_names_texts() -> list[ProtoFile]: partial(RawFixture, package="long_names", mypy=True), id="long_names_raw", ), + pytest.param( + partial(RawProtocFixture, package="long_names", mypy=True), + id="long_names_raw_protoc", + ), ] ) def long_names( @@ -763,6 +875,10 @@ def ignored_import_texts(request: SubRequest) -> list[ProtoFile]: partial(RawFixture, package="ignored_imports"), id="ignored_imports_raw", ), + pytest.param( + partial(RawProtocFixture, package="ignored_imports"), + id="ignored_imports_raw_protoc", + ), ] ) def ignored_imports(