4
4
import os
5
5
import shutil
6
6
import tempfile
7
+ import zipfile
7
8
8
9
import yaml
9
10
from google .cloud import storage
10
11
11
12
CODE_BLOB_PATH = "projects/{project_id}/{experiment_id}/code.zip"
12
13
WANDB_ID_BLOB_PATH = "projects/{project_id}/{experiment_id}/wandb_id"
14
+ CODE_EXCLUDES = [".env" , "wandb" , "rslp/__pycache__" ]
13
15
14
16
bucket = None
15
17
@@ -38,6 +40,40 @@ def get_project_and_experiment(config_path: str) -> tuple[str, str]:
38
40
return project_id , experiment_id
39
41
40
42
43
+ def make_archive (
44
+ zip_filename : str , root_dir : str , exclude_prefixes : list [str ] = []
45
+ ) -> None :
46
+ """Create a zip archive of the contents of root_dir.
47
+
48
+ The paths in the zip archive will be relative to root_dir.
49
+
50
+ This is similar to shutil.make_archive but it allows specifying a list of prefixes
51
+ that should not be added to the zip archive.
52
+
53
+ Args:
54
+ zip_filename: the filename to save archive under.
55
+ root_dir: the directory to create archive of.
56
+ exclude_prefixes: a list of prefixes to exclude from the archive. If the
57
+ relative path of a file from root_dir starts with one of the prefixes, then
58
+ it will not be added to the resulting archive.
59
+ """
60
+
61
+ def should_exclude (rel_path : str ) -> bool :
62
+ for prefix in exclude_prefixes :
63
+ if rel_path .startswith (prefix ):
64
+ return True
65
+ return False
66
+
67
+ with zipfile .ZipFile (zip_filename , "w" , zipfile .ZIP_DEFLATED ) as zipf :
68
+ for root , _ , files in os .walk (root_dir ):
69
+ for fname in files :
70
+ full_path = os .path .join (root , fname )
71
+ rel_path = os .path .relpath (full_path , start = root_dir )
72
+ if should_exclude (rel_path ):
73
+ continue
74
+ zipf .write (full_path , arcname = rel_path )
75
+
76
+
41
77
def upload_code (project_id : str , experiment_id : str ) -> None :
42
78
"""Upload code to GCS that entrypoint should retrieve.
43
79
@@ -50,8 +86,11 @@ def upload_code(project_id: str, experiment_id: str) -> None:
50
86
bucket = _get_bucket ()
51
87
with tempfile .TemporaryDirectory () as tmpdirname :
52
88
print ("creating archive of current code state" )
53
- zip_fname = shutil .make_archive (
54
- os .path .join (tmpdirname , "archive" ), "zip" , root_dir = "."
89
+ zip_fname = os .path .join (tmpdirname , "archive.zip" )
90
+ make_archive (
91
+ zip_fname ,
92
+ root_dir = "." ,
93
+ exclude_prefixes = CODE_EXCLUDES ,
55
94
)
56
95
print ("uploading archive" )
57
96
blob_path = CODE_BLOB_PATH .format (
@@ -73,7 +112,7 @@ def download_code(project_id: str, experiment_id: str) -> None:
73
112
"""
74
113
bucket = _get_bucket ()
75
114
with tempfile .TemporaryDirectory () as tmpdirname :
76
- print ("downloading code acrhive " )
115
+ print ("downloading code archive " )
77
116
blob_path = CODE_BLOB_PATH .format (
78
117
project_id = project_id , experiment_id = experiment_id
79
118
)
0 commit comments