Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and Improve Build and Clean Commands #2917

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 58 additions & 73 deletions build_helpers/build_helpers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import codecs
import errno
import logging
import os
import re
import shutil
import subprocess
from functools import partial
from os.path import abspath, basename, dirname, exists, isdir, join
from os.path import abspath, dirname, exists, isdir, join
from pathlib import Path
from typing import List, Optional

Expand All @@ -16,32 +15,30 @@

log = logging.getLogger(__name__)


def find_version(*file_paths: str) -> str:
with codecs.open(os.path.join(*file_paths), "r") as fp:
version_file = fp.read()
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError("Unable to find version string.")

"""Retrieve the version string from the specified file."""
version_file_path = os.path.join(*file_paths)
with codecs.open(version_file_path, "r", "utf-8") as file:
content = file.read()
match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if match:
return match.group(1)
raise RuntimeError(f"Unable to find version string in {version_file_path}.")

def matches(patterns: List[str], string: str) -> bool:
"""Check if the string matches any of the given regex patterns."""
string = string.replace("\\", "/")
for pattern in patterns:
if re.match(pattern, string):
return True
return False
return any(re.match(pattern, string) for pattern in patterns)


def find_(
def find_files(
root: str,
rbase: str,
include_files: List[str],
include_dirs: List[str],
excludes: List[str],
scan_exclude: List[str],
) -> List[str]:
"""Recursively find files and directories matching the given patterns."""
files = []
scan_root = os.path.join(root, rbase)
with os.scandir(scan_root) as it:
Expand All @@ -55,32 +52,29 @@ def find_(
if not matches(excludes, path):
files.append(path)
else:
ret = find_(
files.extend(find_files(
root=root,
rbase=path,
include_files=include_files,
include_dirs=include_dirs,
excludes=excludes,
scan_exclude=scan_exclude,
)
files.extend(ret)
else:
if matches(include_files, path) and not matches(excludes, path):
files.append(path)
))
elif matches(include_files, path) and not matches(excludes, path):
files.append(path)

return files


def find(
root: str,
include_files: List[str],
include_dirs: List[str],
excludes: List[str],
scan_exclude: Optional[List[str]] = None,
) -> List[str]:
if scan_exclude is None:
scan_exclude = []
return find_(
"""Find all files and directories matching the given patterns starting from the root."""
scan_exclude = scan_exclude or []
return find_files(
root=root,
rbase="",
include_files=include_files,
Expand All @@ -89,14 +83,10 @@ def find(
scan_exclude=scan_exclude,
)


class CleanCommand(Command): # type: ignore
"""
Our custom command to clean out junk files.
"""
class CleanCommand(Command):
"""Custom command to clean out generated and junk files."""

description = "Cleans out generated and junk files we don't want in the repo"
dry_run: bool
user_options: List[str] = []

def run(self) -> None:
Expand All @@ -116,74 +106,69 @@ def run(self) -> None:
)

if self.dry_run:
print("Would clean up the following files and dirs")
print("Would clean up the following files and dirs:")
print("\n".join(files))
else:
for f in files:
if exists(f):
if isdir(f):
shutil.rmtree(f, ignore_errors=True)
for file_path in files:
if exists(file_path):
if isdir(file_path):
shutil.rmtree(file_path, ignore_errors=True)
else:
os.unlink(f)
os.unlink(file_path)

def initialize_options(self) -> None:
pass

def finalize_options(self) -> None:
pass


def run_antlr(cmd: Command) -> None:
"""Execute the ANTLR command to generate parsers."""
try:
log.info("Generating parsers with antlr4")
cmd.run_command("antlr")
except OSError as e:
if e.errno == errno.ENOENT:
msg = f"| Unable to generate parsers: {e} |"
msg = "=" * len(msg) + "\n" + msg + "\n" + "=" * len(msg)
log.critical(f"{msg}")
if e.errno == os.errno.ENOENT:
msg = f"Unable to generate parsers: {e}"
log.critical(f"{'=' * len(msg)}\n{msg}\n{'=' * len(msg)}")
exit(1)
else:
raise

raise

class BuildPyCommand(build_py.build_py):
def run(self) -> None:
if not self.dry_run:
self.run_command("clean")
run_antlr(self)
build_py.build_py.run(self)
super().run()


class Develop(develop.develop):
def run(self) -> None: # type: ignore
class DevelopCommand(develop.develop):
def run(self) -> None:
if not self.dry_run:
run_antlr(self)
develop.develop.run(self)

super().run()

class SDistCommand(sdist.sdist):
def run(self) -> None:
if not self.dry_run: # type: ignore
if not self.dry_run:
self.run_command("clean")
run_antlr(self)
sdist.sdist.run(self)
super().run()


class ANTLRCommand(Command): # type: ignore
class ANTLRCommand(Command):
"""Generate parsers using ANTLR."""

description = "Run ANTLR"
user_options: List[str] = []

def run(self) -> None:
"""Run command."""
"""Run the ANTLR command to generate parsers."""
root_dir = abspath(dirname(__file__))
project_root = abspath(dirname(basename(__file__)))
for grammar in [
project_root = abspath(dirname(root_dir))
grammars = [
"hydra/grammar/OverrideLexer.g4",
"hydra/grammar/OverrideParser.g4",
]:
]
for grammar in grammars:
command = [
"java",
"-jar",
Expand All @@ -199,7 +184,6 @@ def run(self) -> None:
log.info(f"Generating parser for Python3: {command}")

subprocess.check_call(command)

log.info("Replacing imports of antlr4 in generated parsers")
self._fix_imports()

Expand All @@ -210,27 +194,28 @@ def finalize_options(self) -> None:
pass

def _fix_imports(self) -> None:
"""Fix imports from the generated parsers to use the vendored antlr4 instead"""
"""Fix imports in the generated parsers to use the vendored antlr4."""
build_dir = Path(__file__).parent.absolute()
project_root = build_dir.parent
lib = "antlr4"
pkgname = 'omegaconf.vendor'

replacements = [
partial( # import antlr4 -> import omegaconf.vendor.antlr4
re.compile(r'(^\s*)import {}\n'.format(lib), flags=re.M).sub,
r'\1from {} import {}\n'.format(pkgname, lib)
partial(
re.compile(rf'(^\s*)import {lib}\n', flags=re.M).sub,
rf'\1from {pkgname} import {lib}\n'
),
partial( # from antlr4 -> from fomegaconf.vendor.antlr4
re.compile(r'(^\s*)from {}(\.|\s+)'.format(lib), flags=re.M).sub,
r'\1from {}.{}\2'.format(pkgname, lib)
partial(
re.compile(rf'(^\s*)from {lib}(\.|\s+)', flags=re.M).sub,
rf'\1from {pkgname}.{lib}\2'
),
]

path = project_root / "hydra" / "grammar" / "gen"
for item in path.iterdir():
if item.is_file() and item.name.endswith(".py"):
text = item.read_text('utf8')
gen_path = project_root / "hydra" / "grammar" / "gen"
for item in gen_path.iterdir():
if item.is_file() and item.suffix == ".py":
text = item.read_text('utf-8')
for replacement in replacements:
text = replacement(text)
item.write_text(text, 'utf8')
item.write_text(text, 'utf-8')

Loading