Skip to content

Commit

Permalink
Optimize performance of code optimizer
Browse files Browse the repository at this point in the history
Ref. eng/recordflux/RecordFlux#1525
  • Loading branch information
treiher committed Feb 28, 2024
1 parent a821cfd commit 9217696
Show file tree
Hide file tree
Showing 9 changed files with 15,183 additions and 73 deletions.
4 changes: 3 additions & 1 deletion doc/user_guide/90-rflx-generate--help.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ usage: rflx generate [-h] [-p PREFIX] [-n] [-d OUTPUT_DIRECTORY]
[--debug {built-in,external}]
[--ignore-unsupported-checksum]
[--integration-files-dir INTEGRATION_FILES_DIR]
[--reproducible] [--optimize]
[--reproducible] [--optimize] [--timeout TIMEOUT]
[SPECIFICATION_FILE ...]

positional arguments:
Expand All @@ -23,3 +23,5 @@ options:
--reproducible ensure reproducible output
--optimize optimize generated state machine code (requires
GNATprove)
--timeout TIMEOUT prover timeout in seconds for code optimization
(default: 1)
7 changes: 4 additions & 3 deletions doc/user_guide/90-rflx-optimize--help.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
usage: rflx optimize [-h] DIRECTORY
usage: rflx optimize [-h] [--timeout TIMEOUT] DIRECTORY

positional arguments:
DIRECTORY directory containing the generated code
DIRECTORY directory containing the generated code

options:
-h, --help show this help message and exit
-h, --help show this help message and exit
--timeout TIMEOUT prover timeout in seconds (default: 1)
16 changes: 14 additions & 2 deletions rflx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ def main( # noqa: PLR0915
action="store_true",
help="optimize generated state machine code (requires GNATprove)",
)
parser_generate.add_argument(
"--timeout",
type=int,
default=1,
help="prover timeout in seconds for code optimization (default: %(default)s)",
)
parser_generate.add_argument(
"files",
metavar="SPECIFICATION_FILE",
Expand All @@ -252,6 +258,12 @@ def main( # noqa: PLR0915
"optimize",
help="optimize generated state machine code",
)
parser_optimize.add_argument(
"--timeout",
type=int,
default=1,
help="prover timeout in seconds (default: %(default)s)",
)
parser_optimize.add_argument(
"directory",
metavar="DIRECTORY",
Expand Down Expand Up @@ -553,14 +565,14 @@ def generate(args: argparse.Namespace) -> None:
)

if args.optimize:
optimizer.optimize(args.output_directory, args.workers)
optimizer.optimize(args.output_directory, args.workers, args.timeout)


def optimize(args: argparse.Namespace) -> None:
if not args.directory.is_dir():
fail(f'directory not found: "{args.directory}"', Subsystem.CLI)

optimizer.optimize(args.directory, args.workers)
optimizer.optimize(args.directory, args.workers, args.timeout)


def parse(
Expand Down
119 changes: 85 additions & 34 deletions rflx/generator/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import json
import logging
import re
import shutil
from collections import defaultdict
import tempfile
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from subprocess import PIPE, STDOUT, run
from tempfile import TemporaryDirectory

from rflx.error import Subsystem, fail
from rflx.spark import SPARKFile

log = logging.getLogger(__name__)

Expand All @@ -28,36 +30,36 @@ def valid(self) -> bool:
return 0 < self.begin < self.assertion < self.end


def optimize(generated_dir: Path, workers: int = 0) -> None:
def optimize(generated_dir: Path, workers: int = 0, timeout: int = 1) -> None:
"""Remove unnecessary checks in generated state machine code."""

if not gnatprove_found():
fail("GNATprove is required for code optimization", Subsystem.GENERATOR)

with TemporaryDirectory() as tmp_dir_name:
tmp_dir = Path(tmp_dir_name)
shutil.copytree(generated_dir, tmp_dir, dirs_exist_ok=True)

checks: dict[Path, dict[int, Check]] = defaultdict(dict)
checks: dict[Path, dict[int, Check]] = {}

for f in tmp_dir.glob("*.adb"):
checks[f] = instrument(f)

i = 1
total = sum(1 for cs in checks.values() for c in cs)
checks_to_remove: dict[Path, dict[int, Check]] = defaultdict(dict)
cs = instrument(f)
if cs:
checks[f] = cs

for f, cs in checks.items():
for assertion, check in cs.items():
log.info("Analyzing %s (%d/%d)", f.name, i, total)
checks_to_remove[f] |= analyze(f, {assertion: check}, workers)
i += 1
checks_to_remove: dict[Path, dict[int, Check]] = {}

for f in checks_to_remove:
remove(f, checks_to_remove[f], checks[f])
shutil.copy(f, generated_dir)
for i, (f, cs) in enumerate(checks.items(), start=1):
log.info("Analyzing %s (%d/%d)", f.name, i, len(checks))
checks_to_remove[f] = analyze(f, cs, workers, timeout)
if checks_to_remove[f]:
remove(f, checks_to_remove[f], checks[f])
shutil.copy(f, generated_dir)

log.info(
"Optimization completed: %d checks removed",
"Optimization completed: %d/%d checks removed",
sum(1 for cs in checks_to_remove.values() for c in cs),
sum(1 for cs in checks.values() for c in cs),
)


Expand All @@ -66,6 +68,7 @@ def gnatprove_found() -> bool:


def instrument(file: Path) -> dict[int, Check]:
"""Add an always false assertion before each goto statement."""
checks: dict[int, Check] = {}
content = file.read_text()
instrumented_content = ""
Expand Down Expand Up @@ -96,40 +99,88 @@ def instrument(file: Path) -> dict[int, Check]:
return checks


def analyze(file: Path, checks: dict[int, Check], workers: int = 0) -> dict[int, Check]:
def analyze(
file: Path,
checks: dict[int, Check],
workers: int = 0,
timeout: int = 1,
) -> dict[int, Check]:
"""Analyze file and return removable checks."""

result: dict[int, Check] = {}

if not checks:
return result

for i in checks:
if prove(file, i, workers):
result[i] = checks[i]
proof_results = prove(file, workers, timeout)

for line in checks:
if proof_results[line]:
result[line] = checks[line]

return result


def prove(file: Path, line: int, workers: int = 0) -> bool:
return (
run(
def prove(file: Path, workers: int = 0, timeout: int = 1) -> dict[int, bool]:
"""Prove file and return results for all assertions."""

with tempfile.TemporaryDirectory() as tmp_dir_name:
tmp_dir = Path(tmp_dir_name)
project_file = tmp_dir / "optimize.gpr"
project_file.write_text(
f'project Optimize is\n for Source_Dirs use ("{file.parent}");\nend Optimize;\n',
)

p = run(
[
"gnatprove",
f"--limit-line={file.name}:{line}",
f"-j{workers}",
"--prover=all",
"--timeout=1",
"--checks-as-errors=on",
"--quiet",
"-P",
str(project_file),
"-u",
file.name,
"-j",
str(workers),
"--prover=z3,cvc5",
"--timeout",
str(timeout),
],
cwd=file.parent,
stdout=PIPE,
stderr=STDOUT,
).returncode
== 0
)
)

if p.returncode != 0:
fail(
f"gnatprove terminated with exit code {p.returncode}"
+ ("\n" + p.stdout.decode("utf-8") if p.stdout is not None else ""),
Subsystem.GENERATOR,
)

return get_proof_results_for_asserts(
project_file.parent / "gnatprove" / file.with_suffix(".spark").name,
)


def get_proof_results_for_asserts(spark_file: Path) -> dict[int, bool]:
    """
    Return the proof result of each assertion recorded in a SPARK result file.

    The JSON content of `spark_file` is parsed into a `SPARKFile`. Only proofs whose rule is
    "VC_ASSERT" are considered, all other proofs are skipped.

    The returned dictionary maps the line of an assertion to `True` if at least one proof
    attempt in the first element of its check tree has the result "Valid", and to `False`
    otherwise. An assertion with an empty check tree or without any proof attempt is reported
    as not proved. If several assertion proofs refer to the same line, the last one wins.
    """
    result: dict[int, bool] = {}

    for proof in SPARKFile(**json.loads(spark_file.read_text())).proof:
        if proof.rule != "VC_ASSERT":
            continue

        # Only the first check tree element is inspected. Slicing instead of indexing keeps an
        # empty check tree from raising an IndexError; without a valid proof attempt the
        # assertion is conservatively treated as unproved, so the related check is kept.
        result[proof.line] = any(
            attempt.result == "Valid"
            for element in proof.check_tree[:1]
            for attempt in element.proof_attempts.values()
        )

    return result


def remove(file: Path, checks: dict[int, Check], assertions: Iterable[int]) -> None:
"""Remove checks and always false assertions."""

lines = {i for c in checks.values() for i in range(c.begin, c.end + 1)} | set(assertions)
content = file.read_text()
optimized_content = ""
Expand All @@ -139,4 +190,4 @@ def remove(file: Path, checks: dict[int, Check], assertions: Iterable[int]) -> N
optimized_content += f"{l}\n"

assert "pragma Assert (False)" not in optimized_content
file.write_text(optimized_content.strip())
file.write_text(optimized_content.strip() + "\n")
27 changes: 27 additions & 0 deletions rflx/spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

import typing as ty

from pydantic import BaseModel


class ProofAttempt(BaseModel):
    """Outcome of a single proof attempt as stored in a SPARK result file."""

    # Verdict of the attempt; the optimizer treats the value "Valid" as a successful proof.
    result: str
    # Presumably the number of steps the prover needed -- TODO confirm against GNATprove docs.
    steps: int
    # Presumably the time the prover needed -- TODO confirm unit.
    # NOTE(review): declared as int; confirm that GNATprove never emits fractional values here.
    time: int


class CheckTreeElement(BaseModel):
    """One element of the check tree that belongs to a proof."""

    # Proof attempts keyed by a string; presumably the name of the prover -- TODO confirm.
    proof_attempts: ty.Mapping[str, ProofAttempt]


class Proof(BaseModel):
    """A single entry of the proof list of a SPARK result file."""

    # Presumably the name of the source file the check belongs to -- TODO confirm.
    file: str
    # Line of the check; the optimizer uses it as key for the proof results of assertions.
    line: int
    # Presumably the column of the check within the line -- TODO confirm.
    col: int
    # Kind of the check, e.g. "VC_ASSERT" for assertions.
    rule: str
    # Elements holding the proof attempts made for this check; may be empty as far as the
    # declared type is concerned.
    check_tree: ty.Sequence[CheckTreeElement]


class SPARKFile(BaseModel):
    """Model of the part of a SPARK result file (JSON) that is needed to read proof results."""

    # All proof entries contained in the file.
    proof: ty.Sequence[Proof]
Loading

0 comments on commit 9217696

Please sign in to comment.