Skip to content

Commit 98f9a97

Browse files
authored
Merge pull request #51 from allenai/favyen/issue1
Ignore files when creating code archive to launch Beaker job
2 parents 5b51777 + 5c16ffb commit 98f9a97

File tree

2 files changed

+74
-3
lines changed

2 files changed

+74
-3
lines changed

rslp/launcher_lib.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import os
55
import shutil
66
import tempfile
7+
import zipfile
78

89
import yaml
910
from google.cloud import storage
1011

1112
CODE_BLOB_PATH = "projects/{project_id}/{experiment_id}/code.zip"
1213
WANDB_ID_BLOB_PATH = "projects/{project_id}/{experiment_id}/wandb_id"
14+
CODE_EXCLUDES = [".env", "wandb", "rslp/__pycache__"]
1315

1416
bucket = None
1517

@@ -38,6 +40,40 @@ def get_project_and_experiment(config_path: str) -> tuple[str, str]:
3840
return project_id, experiment_id
3941

4042

43+
def make_archive(
    zip_filename: str, root_dir: str, exclude_prefixes: list[str] = []
) -> None:
    """Create a zip archive of the contents of root_dir.

    The paths in the zip archive will be relative to root_dir and always use forward
    slashes as the separator.

    This is similar to shutil.make_archive but it allows specifying a list of prefixes
    that should not be added to the zip archive.

    The archive file itself is never added to the archive, so zip_filename may be
    located inside root_dir.

    Args:
        zip_filename: the filename to save archive under.
        root_dir: the directory to create archive of.
        exclude_prefixes: a list of prefixes to exclude from the archive. If the
            relative path of a file from root_dir starts with one of the prefixes, then
            it will not be added to the resulting archive. Prefixes should be written
            with forward slashes on every platform. The list is only read, never
            modified, so the shared default value is safe.
    """
    # str.startswith accepts a tuple of candidates; an empty tuple never matches.
    prefixes = tuple(exclude_prefixes)
    # Remember where the archive lives so that it is not added to itself when it is
    # being written somewhere underneath root_dir.
    archive_path = os.path.abspath(zip_filename)

    def relative_name(path: str) -> str:
        # Normalize to forward slashes so the prefixes match on Windows as well.
        return os.path.relpath(path, start=root_dir).replace(os.sep, "/")

    with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(root_dir):
            # Prune excluded directories in place so os.walk does not descend into
            # them. Every file below such a directory shares the matching prefix, so
            # this does not change which files end up in the archive.
            dirs[:] = [
                dname
                for dname in dirs
                if not relative_name(os.path.join(root, dname)).startswith(prefixes)
            ]
            for fname in files:
                full_path = os.path.join(root, fname)
                if os.path.abspath(full_path) == archive_path:
                    continue
                rel_path = relative_name(full_path)
                if rel_path.startswith(prefixes):
                    continue
                zipf.write(full_path, arcname=rel_path)
75+
76+
4177
def upload_code(project_id: str, experiment_id: str) -> None:
4278
"""Upload code to GCS that entrypoint should retrieve.
4379
@@ -50,8 +86,11 @@ def upload_code(project_id: str, experiment_id: str) -> None:
5086
bucket = _get_bucket()
5187
with tempfile.TemporaryDirectory() as tmpdirname:
5288
print("creating archive of current code state")
53-
zip_fname = shutil.make_archive(
54-
os.path.join(tmpdirname, "archive"), "zip", root_dir="."
89+
zip_fname = os.path.join(tmpdirname, "archive.zip")
90+
make_archive(
91+
zip_fname,
92+
root_dir=".",
93+
exclude_prefixes=CODE_EXCLUDES,
5594
)
5695
print("uploading archive")
5796
blob_path = CODE_BLOB_PATH.format(
@@ -73,7 +112,7 @@ def download_code(project_id: str, experiment_id: str) -> None:
73112
"""
74113
bucket = _get_bucket()
75114
with tempfile.TemporaryDirectory() as tmpdirname:
76-
print("downloading code acrhive")
115+
print("downloading code archive")
77116
blob_path = CODE_BLOB_PATH.format(
78117
project_id=project_id, experiment_id=experiment_id
79118
)

tests/unit/test_launcher_lib.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pathlib
2+
import zipfile
3+
4+
from rslp.launcher_lib import make_archive
5+
6+
7+
def test_make_archive(tmp_path: pathlib.Path) -> None:
    """Check that make_archive leaves out every path matching an excluded prefix.

    Two kinds of prefix are covered: one that exactly names a file directly under
    the root, and one that names a subdirectory whose contents must all be skipped.
    """
    root_dir = tmp_path / "root"

    # Directory layout: one subdirectory that should be ignored, one that is kept.
    for subdir_name in ("ignored_subdir", "okay_subdir"):
        (root_dir / "dir" / subdir_name).mkdir(parents=True, exist_ok=True)

    # Empty files, both to-be-kept and to-be-ignored, spread across the tree.
    for rel_name in (
        "okay_file1",
        "ignored_file",
        "dir/okay_file2",
        "dir/ignored_subdir/also_ignored",
        "dir/okay_subdir/okay_file3",
    ):
        (root_dir / rel_name).touch()

    zip_fname = str(tmp_path / "archive.zip")
    make_archive(
        zip_fname,
        str(root_dir),
        exclude_prefixes=["ignored_file", "dir/ignored_subdir"],
    )

    with zipfile.ZipFile(zip_fname) as zipf:
        members = zipf.namelist()

    # Exactly the three non-excluded files are present, each exactly once.
    expected = ["dir/okay_file2", "dir/okay_subdir/okay_file3", "okay_file1"]
    assert sorted(members) == expected

0 commit comments

Comments
 (0)