Skip to content

Commit

Permalink
Improve gitutils.py by using the Github CLI when available (#1088)
Browse files Browse the repository at this point in the history
The `gitutils.py` script often fails when you are working in a PR, or if there has been a commit to your target branch since you created your current branch.

This improves the utility by using the Github CLI to query the PR information when possible and get the true target branch for the current PR.

If the Github CLI is not available, or the branch isnt in a PR, it falls back to the previous functionality.

Authors:
  - Michael Demoret (https://github.com/mdemoret-nv)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: #1088
  • Loading branch information
mdemoret-nv authored Aug 17, 2023
1 parent 70cceef commit aba421e
Show file tree
Hide file tree
Showing 4 changed files with 632 additions and 318 deletions.
184 changes: 95 additions & 89 deletions ci/scripts/copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
# set up imports
import gitutils # noqa: E402

# pylint: disable=global-statement

FilesToCheck = [
# Get all of these extensions and templates (*.in)
re.compile(r"[.](cmake|cpp|cc|cu|cuh|h|hpp|md|rst|sh|pxd|py|pyx|yml|yaml)(\.in)?$"),
Expand All @@ -41,26 +43,29 @@

# Nothing in a build folder or .cache
ExemptFiles: typing.List[re.Pattern] = [
r"(_version|versioneer)\.py", # Skip versioning files
r"^[^ \/\n]*\.cache[^ \/\n]*\/.*$", # Ignore .cache folder
r"^[^ \/\n]*build[^ \/\n]*\/.*$", # Ignore any build*/ folder
r"^external\/.*$", # Ignore external
r"[^ \/\n]*docs/source/(_lib|_modules|_templates)/.*$",
r"PULL_REQUEST_TEMPLATE.md" # Ignore the PR template
re.compile(r"(_version|versioneer)\.py"), # Skip versioning files
re.compile(r"^[^ \/\n]*\.cache[^ \/\n]*\/.*$"), # Ignore .cache folder
re.compile(r"^[^ \/\n]*build[^ \/\n]*\/.*$"), # Ignore any build*/ folder
re.compile(r"^external\/.*$"), # Ignore external
re.compile(r"[^ \/\n]*docs/source/(_lib|_modules|_templates)/.*$"),
re.compile(r"PULL_REQUEST_TEMPLATE.md"), # Ignore the PR template
]

# this will break starting at year 10000, which is probably OK :)
CheckSimple = re.compile(r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)")
CheckDouble = re.compile(r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501
)
CheckApacheLic = 'Licensed under the Apache License, Version 2.0 (the "License");'
CheckDouble = re.compile(r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)")
CHECK_APACHE_LIC = 'Licensed under the Apache License, Version 2.0 (the "License");'


def is_file_empty(f):
return os.stat(f).st_size == 0


def checkThisFile(f):
def check_this_file(f):
# This check covers things like symlinks which point to files that DNE
if not (os.path.exists(f)):
return False
if gitutils and gitutils.isFileEmpty(f):
if is_file_empty(f):
return False
for exempt in ExemptFiles:
if exempt.search(f):
Expand All @@ -71,7 +76,7 @@ def checkThisFile(f):
return False


def getCopyrightYears(line):
def get_copyright_years(line):
res = CheckSimple.search(line)
if res:
return (int(res.group(1)), int(res.group(1)))
Expand All @@ -81,14 +86,16 @@ def getCopyrightYears(line):
return (None, None)


def replaceCurrentYear(line, start, end):
def replace_current_year(line, start, end):
# first turn a simple regex into double (if applicable). then update years
res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line)

# pylint: disable=consider-using-f-string
res = CheckDouble.sub(r"Copyright (c) {:04d}-{:04d}, NVIDIA CORPORATION".format(start, end), res)
return res


def insertLicense(f, this_year, first_line):
def insert_license(f, this_year, first_line):
ext = os.path.splitext(f)[1].lstrip('.')

try:
Expand All @@ -97,87 +104,87 @@ def insertLicense(f, this_year, first_line):
return [
f,
0,
"Unsupported extension {} for automatic insertion, "
f"Unsupported extension {ext} for automatic insertion, "
"please manually insert an Apache v2.0 header or add the file to "
"excempted from this check add it to the 'ExemptFiles' list in "
"the 'ci/scripts/copyright.py' file (manual fix required)".format(ext),
"the 'ci/scripts/copyright.py' file (manual fix required)",
None
]

# If the file starts with a #! keep it as the first line
if first_line.startswith("#!"):
replace_line = first_line + license_text
else:
replace_line = "{}\n{}".format(license_text, first_line)
replace_line = f"{license_text}\n{first_line}"

return [f, 1, "License inserted", replace_line]


def checkCopyright(f,
update_current_year,
verify_apache_v2=False,
update_start_year=False,
insert_license=False,
git_add=False):
def check_copyright(f,
update_current_year,
verify_apache_v2=False,
update_start_year=False,
do_insert_license=False,
git_add=False):
"""
Checks for copyright headers and their years
"""
errs = []
thisYear = datetime.datetime.now().year
lineNum = 0
crFound = False
apacheLicFound = not verify_apache_v2
yearMatched = False
with io.open(f, "r", encoding="utf-8") as fp:
lines = fp.readlines()
this_year = datetime.datetime.now().year
line_num = 0
cr_found = False
apache_lic_found = not verify_apache_v2
year_matched = False
with io.open(f, "r", encoding="utf-8") as file:
lines = file.readlines()
for line in lines:
lineNum += 1
if not apacheLicFound:
apacheLicFound = CheckApacheLic in line
line_num += 1
if not apache_lic_found:
apache_lic_found = CHECK_APACHE_LIC in line

start, end = getCopyrightYears(line)
start, end = get_copyright_years(line)
if start is None:
continue

crFound = True
cr_found = True
if update_start_year:
try:
git_start = gitutils.determine_add_date(f).year
git_start = gitutils.get_file_add_date(f).year
if start > git_start:
e = [
f,
lineNum,
line_num,
"Current year not included in the "
"copyright header",
replaceCurrentYear(line, git_start, thisYear)
replace_current_year(line, git_start, this_year)
]
errs.append(e)
continue

except Exception as excp:
e = [f, lineNum, "Error determining start year from git: {}".format(excp), None]
e = [f, line_num, f"Error determining start year from git: {excp}", None]
errs.append(e)
continue

if start > end:
e = [f, lineNum, "First year after second year in the copyright header (manual fix required)", None]
e = [f, line_num, "First year after second year in the copyright header (manual fix required)", None]
errs.append(e)
if thisYear < start or thisYear > end:
e = [f, lineNum, "Current year not included in the copyright header", None]
if thisYear < start:
e[-1] = replaceCurrentYear(line, thisYear, end)
if thisYear > end:
e[-1] = replaceCurrentYear(line, start, thisYear)
if this_year < start or this_year > end:
e = [f, line_num, "Current year not included in the copyright header", None]
if this_year < start:
e[-1] = replace_current_year(line, this_year, end)
if this_year > end:
e[-1] = replace_current_year(line, start, this_year)
errs.append(e)
else:
yearMatched = True
fp.close()

if not apacheLicFound:
if insert_license and len(lines):
e = insertLicense(f, thisYear, lines[0])
crFound = True
yearMatched = True
year_matched = True
file.close()

if not apache_lic_found:
if do_insert_license and len(lines):
e = insert_license(f, this_year, lines[0])
cr_found = True
year_matched = True
else:
e = [
f,
Expand All @@ -190,44 +197,43 @@ def checkCopyright(f,
errs.append(e)

# copyright header itself not found
if not crFound:
if not cr_found:
e = [f, 0, "Copyright header missing or formatted incorrectly (manual fix required)", None]
errs.append(e)

# even if the year matches a copyright header, make the check pass
if yearMatched and apacheLicFound:
if year_matched and apache_lic_found:
errs = []

if update_current_year or update_start_year or insert_license:
if update_current_year or update_start_year or do_insert_license:
errs_update = [x for x in errs if x[-1] is not None]
if len(errs_update) > 0:
logging.info("File: {}. Changing line(s) {}".format(f,
', '.join(str(x[1]) for x in errs
if x[-1] is not None)))
for _, lineNum, __, replacement in errs_update:
lines[lineNum - 1] = replacement
logging.info("File: %s. Changing line(s) %s", f, ', '.join(str(x[1]) for x in errs if x[-1] is not None))
for _, line_num, __, replacement in errs_update:
lines[line_num - 1] = replacement
with io.open(f, "w", encoding="utf-8") as out_file:
for new_line in lines:
out_file.write(new_line)

if git_add:
gitutils.add(f)
gitutils.add_files(f)

errs = [x for x in errs if x[-1] is None]

return errs


def checkCopyright_main():
def _main():
"""
Checks for copyright headers in all the modified files. In case of local
repo, this script will just look for uncommitted files and in case of CI
it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch"
"""
retVal = 0
global ExemptFiles
log_level = logging.getLevelName(os.environ.get("MORPHEUS_LOG_LEVEL", "INFO"))
logging.basicConfig(format="%(levelname)s:%(message)s", level=log_level)

logging.basicConfig(level=logging.DEBUG)
ret_val = 0
global ExemptFiles

argparser = argparse.ArgumentParser("Checks for a consistent copyright header in git's modified files")
argparser.add_argument("--update-start-year",
Expand Down Expand Up @@ -312,36 +318,35 @@ def checkCopyright_main():

(args, dirs) = argparser.parse_known_args()
try:
ExemptFiles = ExemptFiles + [pathName for pathName in args.exclude]
ExemptFiles = [re.compile(file) for file in ExemptFiles]
except re.error as reException:
logging.exception("Regular expression error: %s", reException, exc_info=True)
ExemptFiles = ExemptFiles + [re.compile(pathName) for pathName in args.exclude]
except re.error as re_exception:
logging.exception("Regular expression error: %s", re_exception, exc_info=True)
return 1

if args.git_modified_only:
files = gitutils.modifiedFiles()
files = gitutils.modified_files()
elif args.git_diff_commits:
files = gitutils.changedFilesBetweenCommits(*args.git_diff_commits)
files = gitutils.changed_files(*args.git_diff_commits)
elif args.git_diff_staged:
files = gitutils.stagedFiles(args.git_diff_staged)
files = gitutils.staged_files(args.git_diff_staged)
else:
files = gitutils.list_files_under_source_control(ref="HEAD", *dirs)
files = gitutils.all_files(*dirs)

logging.debug("File count before filter(): %s", len(files))

# Now filter the files down based on the exclude/include
files = gitutils.filter_files(files, path_filter=checkThisFile)
files = gitutils.filter_files(files, path_filter=check_this_file)

logging.info("Checking files (%s):\n %s", len(files), "\n ".join(files))

errors = []
for f in files:
errors += checkCopyright(f,
args.update_current_year,
verify_apache_v2=(args.verify_apache_v2 or args.insert or args.fix_all),
update_start_year=(args.update_start_year or args.fix_all),
insert_license=(args.insert or args.fix_all),
git_add=args.git_add)
errors += check_copyright(f,
args.update_current_year,
verify_apache_v2=(args.verify_apache_v2 or args.insert or args.fix_all),
update_start_year=(args.update_start_year or args.fix_all),
do_insert_license=(args.insert or args.fix_all),
git_add=args.git_add)

if len(errors) > 0:
logging.info("Copyright headers incomplete in some of the files!")
Expand All @@ -352,14 +357,16 @@ def checkCopyright_main():
path_parts = os.path.abspath(__file__).split(os.sep)
file_from_repo = os.sep.join(path_parts[path_parts.index("ci"):])
if n_fixable > 0:
logging.info(("You can run `python {} --git-modified-only "
"--update-current-year --insert` to fix {} of these "
"errors.\n").format(file_from_repo, n_fixable))
retVal = 1
logging.info(("You can run `python %s --git-modified-only "
"--update-current-year --insert` to fix %s of these "
"errors.\n"),
file_from_repo,
n_fixable)
ret_val = 1
else:
logging.info("Copyright check passed")

return retVal
return ret_val


A2_LIC_HASH = """# SPDX-FileCopyrightText: Copyright (c) {YEAR}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Expand Down Expand Up @@ -454,5 +461,4 @@ def checkCopyright_main():
}

if __name__ == "__main__":
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
sys.exit(checkCopyright_main())
sys.exit(_main())
Loading

0 comments on commit aba421e

Please sign in to comment.