Skip to content

Commit

Permalink
feat: Creating one py_binary per main module (#1584)
Browse files Browse the repository at this point in the history
Many existing Python repos don't use `__main__.py` to indicate the
main module. Instead, they put something like the snippet below in any Python file:

```python
if __name__ == "__main__":
  main()
```

This PR makes the Gazelle extension able to recognize main modules like
this, when `__main__.py` doesn't exist. This reduces the need to create
`__main__.py` when enabling Gazelle extensions in existing Python repos.

The current behavior of creating a single `py_binary` for `__main__.py` is
preserved and takes precedence. So this is a backward-compatible change.

Closes #1566.
  • Loading branch information
linzhp authored Dec 13, 2023
1 parent 27450f9 commit 6ffb04e
Show file tree
Hide file tree
Showing 17 changed files with 253 additions and 60 deletions.
1 change: 1 addition & 0 deletions examples/bzlmod/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ compile_pip_requirements_3_9(
# with pip-compile.
compile_pip_requirements_3_10(
name = "requirements_3_10",
timeout = "moderate",
src = "requirements.in",
requirements_txt = "requirements_lock_3_10.txt",
requirements_windows = "requirements_windows_3_10.txt",
Expand Down
11 changes: 10 additions & 1 deletion gazelle/python/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
load("@bazel_gazelle//:def.bzl", "gazelle_binary")
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@rules_python//python:defs.bzl", "py_binary")
load("@rules_python//python:defs.bzl", "py_binary", "py_test")
load(":gazelle_test.bzl", "gazelle_test")

go_library(
Expand Down Expand Up @@ -58,6 +58,15 @@ py_binary(
visibility = ["//visibility:public"],
)

# Unit test target for the Python parse helper: srcs lists both the module
# under test (parse.py) and its test file (parse_test.py).
py_test(
name = "parse_test",
srcs = [
"parse.py",
"parse_test.py",
],
# NOTE(review): presumably puts this package directory on the import path so
# parse_test.py can `import parse` directly — confirm against rules_python docs.
imports = ["."],
)

filegroup(
name = "helper.zip",
srcs = [":helper"],
Expand Down
2 changes: 1 addition & 1 deletion gazelle/python/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

if __name__ == "__main__":
if len(sys.argv) < 2:
sys.exit("Please provide subcommand, either print or std_modules")
sys.exit("Please provide subcommand, either parse or std_modules")
if sys.argv[1] == "parse":
sys.exit(parse.main(sys.stdin, sys.stdout))
elif sys.argv[1] == "std_modules":
Expand Down
115 changes: 65 additions & 50 deletions gazelle/python/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"log"
"os"
"path/filepath"
"sort"
"strings"

"github.com/bazelbuild/bazel-gazelle/config"
Expand Down Expand Up @@ -89,9 +90,9 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes
pyTestFilenames := treeset.NewWith(godsutils.StringComparator)
pyFileNames := treeset.NewWith(godsutils.StringComparator)

// hasPyBinary controls whether a py_binary target should be generated for
// hasPyBinaryEntryPointFile controls whether a single py_binary target should be generated for
// this package or not.
hasPyBinary := false
hasPyBinaryEntryPointFile := false

// hasPyTestEntryPointFile and hasPyTestEntryPointTarget control whether a py_test target should
// be generated for this package or not.
Expand All @@ -106,8 +107,8 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes
ext := filepath.Ext(f)
if ext == ".py" {
pyFileNames.Add(f)
if !hasPyBinary && f == pyBinaryEntrypointFilename {
hasPyBinary = true
if !hasPyBinaryEntryPointFile && f == pyBinaryEntrypointFilename {
hasPyBinaryEntryPointFile = true
} else if !hasPyTestEntryPointFile && f == pyTestEntrypointFilename {
hasPyTestEntryPointFile = true
} else if f == conftestFilename {
Expand Down Expand Up @@ -219,7 +220,7 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes
collisionErrors := singlylinkedlist.New()

appendPyLibrary := func(srcs *treeset.Set, pyLibraryTargetName string) {
deps, err := parser.parse(srcs)
deps, mainModules, err := parser.parse(srcs)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
}
Expand All @@ -228,16 +229,33 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes
// exists, and if it is of a different kind from the one we are
// generating. If so, we have to throw an error since Gazelle won't
// generate it correctly.
if args.File != nil {
for _, t := range args.File.Rules {
if t.Name() == pyLibraryTargetName && t.Kind() != actualPyLibraryKind {
fqTarget := label.New("", args.Rel, pyLibraryTargetName)
err := fmt.Errorf("failed to generate target %q of kind %q: "+
"a target of kind %q with the same name already exists. "+
"Use the '# gazelle:%s' directive to change the naming convention.",
fqTarget.String(), actualPyLibraryKind, t.Kind(), pythonconfig.LibraryNamingConvention)
collisionErrors.Add(err)
if err := ensureNoCollision(args.File, pyLibraryTargetName, actualPyLibraryKind); err != nil {
fqTarget := label.New("", args.Rel, pyLibraryTargetName)
err := fmt.Errorf("failed to generate target %q of kind %q: %w. "+
"Use the '# gazelle:%s' directive to change the naming convention.",
fqTarget.String(), actualPyLibraryKind, err, pythonconfig.LibraryNamingConvention)
collisionErrors.Add(err)
}

if !hasPyBinaryEntryPointFile {
sort.Strings(mainModules)
// Creating one py_binary target per main module when __main__.py doesn't exist.
for _, filename := range mainModules {
pyBinaryTargetName := strings.TrimSuffix(filepath.Base(filename), ".py")
if err := ensureNoCollision(args.File, pyBinaryTargetName, actualPyBinaryKind); err != nil {
fqTarget := label.New("", args.Rel, pyBinaryTargetName)
log.Printf("failed to generate target %q of kind %q: %v",
fqTarget.String(), actualPyBinaryKind, err)
continue
}
srcs.Remove(filename)
pyBinary := newTargetBuilder(pyBinaryKind, pyBinaryTargetName, pythonProjectRoot, args.Rel, pyFileNames).
addVisibility(visibility).
addSrc(filename).
addModuleDependencies(deps).
generateImportsAttribute().build()
result.Gen = append(result.Gen, pyBinary)
result.Imports = append(result.Imports, pyBinary.PrivateAttr(config.GazelleImportsKey))
}
}

Expand Down Expand Up @@ -270,8 +288,8 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes
appendPyLibrary(pyLibraryFilenames, cfg.RenderLibraryName(packageName))
}

if hasPyBinary {
deps, err := parser.parseSingle(pyBinaryEntrypointFilename)
if hasPyBinaryEntryPointFile {
deps, _, err := parser.parseSingle(pyBinaryEntrypointFilename)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
}
Expand All @@ -282,17 +300,12 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes
// exists, and if it is of a different kind from the one we are
// generating. If so, we have to throw an error since Gazelle won't
// generate it correctly.
if args.File != nil {
for _, t := range args.File.Rules {
if t.Name() == pyBinaryTargetName && t.Kind() != actualPyBinaryKind {
fqTarget := label.New("", args.Rel, pyBinaryTargetName)
err := fmt.Errorf("failed to generate target %q of kind %q: "+
"a target of kind %q with the same name already exists. "+
"Use the '# gazelle:%s' directive to change the naming convention.",
fqTarget.String(), actualPyBinaryKind, t.Kind(), pythonconfig.BinaryNamingConvention)
collisionErrors.Add(err)
}
}
if err := ensureNoCollision(args.File, pyBinaryTargetName, actualPyBinaryKind); err != nil {
fqTarget := label.New("", args.Rel, pyBinaryTargetName)
err := fmt.Errorf("failed to generate target %q of kind %q: %w. "+
"Use the '# gazelle:%s' directive to change the naming convention.",
fqTarget.String(), actualPyBinaryKind, err, pythonconfig.BinaryNamingConvention)
collisionErrors.Add(err)
}

pyBinaryTarget := newTargetBuilder(pyBinaryKind, pyBinaryTargetName, pythonProjectRoot, args.Rel, pyFileNames).
Expand All @@ -310,7 +323,7 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes

var conftest *rule.Rule
if hasConftestFile {
deps, err := parser.parseSingle(conftestFilename)
deps, _, err := parser.parseSingle(conftestFilename)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
}
Expand All @@ -319,16 +332,11 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes
// exists, and if it is of a different kind from the one we are
// generating. If so, we have to throw an error since Gazelle won't
// generate it correctly.
if args.File != nil {
for _, t := range args.File.Rules {
if t.Name() == conftestTargetname && t.Kind() != actualPyLibraryKind {
fqTarget := label.New("", args.Rel, conftestTargetname)
err := fmt.Errorf("failed to generate target %q of kind %q: "+
"a target of kind %q with the same name already exists.",
fqTarget.String(), actualPyLibraryKind, t.Kind())
collisionErrors.Add(err)
}
}
if err := ensureNoCollision(args.File, conftestTargetname, actualPyLibraryKind); err != nil {
fqTarget := label.New("", args.Rel, conftestTargetname)
err := fmt.Errorf("failed to generate target %q of kind %q: %w. ",
fqTarget.String(), actualPyLibraryKind, err)
collisionErrors.Add(err)
}

conftestTarget := newTargetBuilder(pyLibraryKind, conftestTargetname, pythonProjectRoot, args.Rel, pyFileNames).
Expand All @@ -346,25 +354,20 @@ func (py *Python) GenerateRules(args language.GenerateArgs) language.GenerateRes

var pyTestTargets []*targetBuilder
newPyTestTargetBuilder := func(srcs *treeset.Set, pyTestTargetName string) *targetBuilder {
deps, err := parser.parse(srcs)
deps, _, err := parser.parse(srcs)
if err != nil {
log.Fatalf("ERROR: %v\n", err)
}
// Check if a target with the same name we are generating already
// exists, and if it is of a different kind from the one we are
// generating. If so, we have to throw an error since Gazelle won't
// generate it correctly.
if args.File != nil {
for _, t := range args.File.Rules {
if t.Name() == pyTestTargetName && t.Kind() != actualPyTestKind {
fqTarget := label.New("", args.Rel, pyTestTargetName)
err := fmt.Errorf("failed to generate target %q of kind %q: "+
"a target of kind %q with the same name already exists. "+
"Use the '# gazelle:%s' directive to change the naming convention.",
fqTarget.String(), actualPyTestKind, t.Kind(), pythonconfig.TestNamingConvention)
collisionErrors.Add(err)
}
}
if err := ensureNoCollision(args.File, pyTestTargetName, actualPyTestKind); err != nil {
fqTarget := label.New("", args.Rel, pyTestTargetName)
err := fmt.Errorf("failed to generate target %q of kind %q: %w. "+
"Use the '# gazelle:%s' directive to change the naming convention.",
fqTarget.String(), actualPyTestKind, err, pythonconfig.TestNamingConvention)
collisionErrors.Add(err)
}
return newTargetBuilder(pyTestKind, pyTestTargetName, pythonProjectRoot, args.Rel, pyFileNames).
addSrcs(srcs).
Expand Down Expand Up @@ -476,3 +479,15 @@ func isEntrypointFile(path string) bool {
return false
}
}

// ensureNoCollision checks whether the build file already declares a rule
// named targetName whose kind differs from kind.
//
// It returns nil when file is nil or when no such rule is found. Otherwise it
// returns an error that names the kind of the first conflicting rule, so the
// caller can wrap it with the target label and any naming-convention hint.
func ensureNoCollision(file *rule.File, targetName, kind string) error {
	if file != nil {
		for _, existing := range file.Rules {
			// Rules with a different name can never collide.
			if existing.Name() != targetName {
				continue
			}
			// Same name and same kind is fine; only a differing kind is a collision.
			if existingKind := existing.Kind(); existingKind != kind {
				return fmt.Errorf("a target of kind %q with the same name already exists", existingKind)
			}
		}
	}
	return nil
}
31 changes: 30 additions & 1 deletion gazelle/python/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import sys
from io import BytesIO
from tokenize import COMMENT, tokenize
from tokenize import COMMENT, NAME, OP, STRING, tokenize


def parse_import_statements(content, filepath):
Expand Down Expand Up @@ -59,6 +59,30 @@ def parse_comments(content):
return comments


def parse_main(content):
    """Report whether ``content`` contains a top-level main guard.

    The source text is tokenized and scanned for an ``if`` keyword starting at
    column 0 that is immediately followed by the tokens ``__name__``, ``==``,
    a string literal equal to ``__main__`` once its surrounding quote
    characters are stripped, and ``:``.

    Args:
        content: Python source code as a ``str``.

    Returns:
        True as soon as one such guard is found, False otherwise (including
        when the token stream ends in the middle of a candidate guard).
    """
    # What must follow a column-0 `if`, in order: (token type, text predicate).
    expected_tail = (
        (NAME, lambda text: text == "__name__"),
        (OP, lambda text: text == "=="),
        (STRING, lambda text: text.strip("\"'") == "__main__"),
        (OP, lambda text: text == ":"),
    )

    stream = tokenize(BytesIO(content.encode("utf-8")).readline)
    for kind, text, begin, _, _ in stream:
        is_top_level_if = kind == NAME and text == "if" and begin[1] == 0
        if not is_top_level_if:
            continue

        for wanted_kind, accepts in expected_tail:
            following = next(stream, None)
            if following is None:
                # Token stream exhausted mid-pattern: no guard can follow.
                return False
            if following[0] != wanted_kind or not accepts(following[1]):
                # Mismatch: drop the consumed token and resume the outer scan.
                break
        else:
            return True

    return False


def parse(repo_root, rel_package_path, filename):
rel_filepath = os.path.join(rel_package_path, filename)
abs_filepath = os.path.join(repo_root, rel_filepath)
Expand All @@ -70,11 +94,16 @@ def parse(repo_root, rel_package_path, filename):
parse_import_statements, content, rel_filepath
)
comments_future = executor.submit(parse_comments, content)
main_future = executor.submit(parse_main, content)
modules = modules_future.result()
comments = comments_future.result()
has_main = main_future.result()

output = {
"filename": filename,
"modules": modules,
"comments": comments,
"has_main": has_main,
}
return output

Expand Down
39 changes: 39 additions & 0 deletions gazelle/python/parse_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest
import parse

class TestParse(unittest.TestCase):
    """Tests for ``parse.parse_main``, which detects a top-level main guard."""

    def test_not_has_main(self):
        """Source without any ``if __name__`` guard is not a main module."""
        content = "a = 1\nb = 2"
        self.assertFalse(parse.parse_main(content))

    def test_has_main_in_function(self):
        """A guard nested inside a function body must not count.

        The indentation inside the fixture is significant: the ``if`` has to
        start past column 0 for this case to exercise the nested path.
        """
        content = """
def foo():
    if __name__ == "__main__":
        a = 3
"""
        self.assertFalse(parse.parse_main(content))

    def test_has_main(self):
        """A guard at column 0 is detected even after other top-level code."""
        content = """
import unittest
from lib import main
class ExampleTest(unittest.TestCase):
    def test_main(self):
        self.assertEqual(
            "",
            main([["A", 1], ["B", 2]]),
        )
if __name__ == "__main__":
    unittest.main()
"""
        self.assertTrue(parse.parse_main(content))


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 6ffb04e

Please sign in to comment.