Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions minion/utils/syncheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import astunparse


def syntax_check(code, verbose=False):
def syntax_check(code, verbose=False) -> None:
try:
ast.parse(code)
return True
Expand All @@ -22,9 +22,9 @@ def syntax_check(code, verbose=False):
class TimeoutError(Exception):
pass

def run_with_timeout(func, args, timeout=None):
def run_with_timeout(func, args, timeout=None) -> None:
result = []
def target():
def target() -> None:
try:
result.append(func(*args))
except Exception as e:
Expand Down Expand Up @@ -76,7 +76,7 @@ def script(
dataset = get_mbpp_plus()
dataset_name = "Mbpp"

print(colored(f"Dataset: {dataset_name}", "blue"))
logging.info(colored(f"Dataset: {dataset_name}", "blue"))

id2solutions = {}
for solution in solutions:
Expand All @@ -88,24 +88,24 @@ def script(
solution["solution"] = dataset[task_id]["prompt"] + solution["completion"]
id2solutions[task_id].append(solution)

print(colored("==============================", "blue"))
print(colored(" ::: Checking completeness... ", "blue"))
print(colored(" ::::: All tasks complete? ", "blue"))
logging.info(colored("==============================", "blue"))
logging.info(colored(" ::: Checking completeness... ", "blue"))
logging.info(colored(" ::::: All tasks complete? ", "blue"))
ndone = 0

task_ids = dataset.keys()
ntask = len(task_ids)
for task_id in task_ids:
if task_id not in id2solutions:
print(colored(f" ⚠️ {task_id} is missing!", "red"))
logging.info(colored(f" ⚠️ {task_id} is missing!", "red"))
continue
nfiles = len(id2solutions[task_id])

if nsample_check is None or nfiles <= nsample_check:
ndone += 1
continue

print(
logging.info(
colored(
f" ⚠️ {task_id} only has {nfiles} samples! But {nsample_check} are expected.",
"red",
Expand All @@ -116,13 +116,13 @@ def script(
if nsample_check is not None:
if ntask != ndone:
ntbd = ntask - ndone
print(colored(f" ::::: ⚠️ {ntbd}/{ntask} tasks incomplete!", "red"))
logging.info(colored(f" ::::: ⚠️ {ntbd}/{ntask} tasks incomplete!", "red"))
else:
print(colored(f" ::::: All {ntask} tasks complete!", "green"))
logging.info(colored(f" ::::: All {ntask} tasks complete!", "green"))

print(colored("==============================", "blue"))
print(colored(" ::: Checking compilation... ", "blue"))
print(colored(" ::::: All code compilable? ", "blue"))
logging.info(colored("==============================", "blue"))
logging.info(colored(" ::: Checking compilation... ", "blue"))
logging.info(colored(" ::::: All code compilable? ", "blue"))
ncode = 0
nwrong = 0
for task_id in task_ids:
Expand All @@ -135,18 +135,18 @@ def script(
code = solution["solution"]
dbg_identifier = solution["_identifier"]
if code.strip() == "":
print(colored(f" ⚠️ {dbg_identifier} is empty!", "red"))
logging.info(colored(f" ⚠️ {dbg_identifier} is empty!", "red"))
nwrong += 1
elif not syntax_check(code, verbose):
print(colored(f" ⚠️ {dbg_identifier} is not compilable!", "red"))
logging.info(colored(f" ⚠️ {dbg_identifier} is not compilable!", "red"))
nwrong += 1
if 0 != nwrong:
print(colored(f" ::::: ⚠️ {nwrong}/{ncode} code are not compilable!", "red"))
logging.info(colored(f" ::::: ⚠️ {nwrong}/{ncode} code are not compilable!", "red"))
else:
print(colored(f" ::::: All {ncode} code are compilable!", "green"))
logging.info(colored(f" ::::: All {ncode} code are compilable!", "green"))


def main():
def main() -> None:
    """Command-line entry point: hand ``script`` to Python Fire.

    Returns nothing; the value produced by ``Fire(script)`` is discarded.
    """
    # Kept as a function-local import, exactly as in the original block.
    from fire import Fire

    Fire(script)
Expand Down
Loading