Skip to content

Commit de5cd7e

Browse files
committed
Ignore .env and wandb and other files during code archive creation.
Resolves #1.
1 parent 72100cb commit de5cd7e

File tree

1 file changed

+39
-2
lines changed

1 file changed

+39
-2
lines changed

rslp/launcher_lib.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import os
55
import shutil
66
import tempfile
7+
import zipfile
78

89
import yaml
910
from google.cloud import storage
1011

1112
CODE_BLOB_PATH = "projects/{project_id}/{experiment_id}/code.zip"
1213
WANDB_ID_BLOB_PATH = "projects/{project_id}/{experiment_id}/wandb_id"
14+
CODE_EXCLUDES = [".env", "wandb", "rslp/__pycache__"]
1315

1416
bucket = None
1517

@@ -38,6 +40,38 @@ def get_project_and_experiment(config_path: str) -> tuple[str, str]:
3840
return project_id, experiment_id
3941

4042

43+
def make_archive(zip_filename: str, root_dir: str, exclude_prefixes: list[str] = []):
44+
"""Create a zip archive of the contents of root_dir.
45+
46+
The paths in the zip archive will be relative to root_dir.
47+
48+
This is similar to shutil.make_archive but it allows specifying a list of prefixes
49+
that should not be added to the zip archive.
50+
51+
Args:
52+
zip_filename: the filename to save archive under.
53+
root_dir: the directory to create archive of.
54+
exclude_prefixes: a list of prefixes to exclude from the archive. If the
55+
relative path of a file from root_dir starts with one of the prefixes, then
56+
it will not be added to the resulting archive.
57+
"""
58+
59+
def should_exclude(rel_path: str) -> bool:
60+
for prefix in exclude_prefixes:
61+
if rel_path.startswith(prefix):
62+
return True
63+
return False
64+
65+
with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
66+
for root, _, files in os.walk(root_dir):
67+
for fname in files:
68+
full_path = os.path.join(root, fname)
69+
rel_path = os.path.relpath(full_path, start=root_dir)
70+
if should_exclude(rel_path):
71+
continue
72+
zipf.write(full_path, arcname=rel_path)
73+
74+
4175
def upload_code(project_id: str, experiment_id: str):
4276
"""Upload code to GCS that entrypoint should retrieve.
4377
@@ -50,8 +84,11 @@ def upload_code(project_id: str, experiment_id: str):
5084
bucket = _get_bucket()
5185
with tempfile.TemporaryDirectory() as tmpdirname:
5286
print("creating archive of current code state")
53-
zip_fname = shutil.make_archive(
54-
os.path.join(tmpdirname, "archive"), "zip", root_dir="."
87+
zip_fname = os.path.join(tmpdirname, "archive.zip")
88+
make_archive(
89+
zip_fname,
90+
root_dir=".",
91+
exclude_prefixes=CODE_EXCLUDES,
5592
)
5693
print("uploading archive")
5794
blob_path = CODE_BLOB_PATH.format(

0 commit comments

Comments
 (0)