Skip to content

Commit

Permalink
Forward additional aot flags, add linting actions.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Mar 27, 2024
1 parent 67ff516 commit 9b09c3c
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 3 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: Lint and test

on:
workflow_dispatch:
push:

jobs:
linting:
runs-on: ubuntu-latest
steps:
- name: Checkout 🛎️
uses: actions/checkout@v4
with:
persist-credentials: false

- name: Install dependencies ☕️
run: |
pip install -U pip setuptools
pip install .[dev]
- name: Lint 🔍
run: flake8 cms_tfaot
7 changes: 6 additions & 1 deletion cms_tfaot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def load_and_normalize_config(config_file: str) -> dict[str, Any]:
return config


def compile_model(config: dict[str, Any], output_dir: str) -> tuple[list[str], list[str]]:
def compile_model(
config: dict[str, Any],
output_dir: str,
additional_flags: list[str] | str | None = None,
) -> tuple[list[str], list[str]]:
from cmsml.scripts.compile_tf_graph import compile_tf_graph

with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -118,6 +122,7 @@ def compile_model(config: dict[str, Any], output_dir: str) -> tuple[list[str], l
compile_class=f"{config['compilation']['namespace']}::{config['model']['name']}_bs{{}}",
xla_flags=config["compilation"].get("xla_flags"),
tf_xla_flags=config["compilation"].get("tf_xla_flags"),
additional_flags=additional_flags,
)
aot_dir = os.path.join(tmp_dir, "aot")

Expand Down
12 changes: 11 additions & 1 deletion cms_tfaot/scripts/tfaot_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def tfaot_compile(
tool_name: str | None = None,
tool_base: str | None = None,
dev: bool = False,
additional_flags: list[str] | str | None = None,
) -> CompilationResult:
# deferred imports
from cms_tfaot import load_and_normalize_config, compile_model, create_wrapper, create_toolfile
Expand All @@ -54,7 +55,11 @@ def tfaot_compile(
tool_base = cmssw_rel_path(output_dir) if dev else "@TOOL_BASE@"

# compile
header_files, object_files = compile_model(config, output_dir)
header_files, object_files = compile_model(
config,
output_dir,
additional_flags=additional_flags,
)

# create subdirectories and move files in dev mode
header_dir = output_dir
Expand Down Expand Up @@ -176,6 +181,10 @@ def main() -> int:
action="store_true",
help="activates the development workflow, setting some variables to sensible defaults",
)
parser.add_argument(
"--additional-flags",
help="additional flags to be passed to the underlying aot compiler invocation",
)
args = parser.parse_args()

tfaot_compile(
Expand All @@ -184,6 +193,7 @@ def main() -> int:
tool_name=args.tool_name,
tool_base=args.tool_base,
dev=args.dev,
additional_flags=args.additional_flags,
)

return 0
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ authors = [
]
license = {file = "LICENSE"}
requires-python = ">=3.7"
dynamic = ["version", "readme", "dependencies"]
dynamic = ["version", "readme", "dependencies", "optional-dependencies"]


[project.scripts]
Expand All @@ -26,6 +26,7 @@ cms_tfaot_compile = "cms_tfaot.scripts.tfaot_compile:main"
version = {attr = "cms_tfaot.__meta__.__version__"}
readme = {file = ["README.md"], content-type = "text/markdown"}
dependencies = {file = ["requirements.txt"]}
optional-dependencies = {dev = {file = ["requirements_dev.txt"]}}


[tool.setuptools.packages.find]
Expand Down
5 changes: 5 additions & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
flake8~=5.0
flake8-commas~=2.1
flake8-quotes~=3.3
pytest-cov>=3.0
pytest-xdist~=3.4.0

0 comments on commit 9b09c3c

Please sign in to comment.