44import os
55import shutil
66import tempfile
7+ import zipfile
78
89import yaml
910from google .cloud import storage
1011
1112CODE_BLOB_PATH = "projects/{project_id}/{experiment_id}/code.zip"
1213WANDB_ID_BLOB_PATH = "projects/{project_id}/{experiment_id}/wandb_id"
14+ CODE_EXCLUDES = [".env" , "wandb" , "rslp/__pycache__" ]
1315
1416bucket = None
1517
@@ -38,6 +40,40 @@ def get_project_and_experiment(config_path: str) -> tuple[str, str]:
3840 return project_id , experiment_id
3941
4042
43+ def make_archive (
44+ zip_filename : str , root_dir : str , exclude_prefixes : list [str ] = []
45+ ) -> None :
46+ """Create a zip archive of the contents of root_dir.
47+
48+ The paths in the zip archive will be relative to root_dir.
49+
50+ This is similar to shutil.make_archive but it allows specifying a list of prefixes
51+ that should not be added to the zip archive.
52+
53+ Args:
54+ zip_filename: the filename to save archive under.
55+ root_dir: the directory to create archive of.
56+ exclude_prefixes: a list of prefixes to exclude from the archive. If the
57+ relative path of a file from root_dir starts with one of the prefixes, then
58+ it will not be added to the resulting archive.
59+ """
60+
61+ def should_exclude (rel_path : str ) -> bool :
62+ for prefix in exclude_prefixes :
63+ if rel_path .startswith (prefix ):
64+ return True
65+ return False
66+
67+ with zipfile .ZipFile (zip_filename , "w" , zipfile .ZIP_DEFLATED ) as zipf :
68+ for root , _ , files in os .walk (root_dir ):
69+ for fname in files :
70+ full_path = os .path .join (root , fname )
71+ rel_path = os .path .relpath (full_path , start = root_dir )
72+ if should_exclude (rel_path ):
73+ continue
74+ zipf .write (full_path , arcname = rel_path )
75+
76+
4177def upload_code (project_id : str , experiment_id : str ) -> None :
4278 """Upload code to GCS that entrypoint should retrieve.
4379
@@ -50,8 +86,11 @@ def upload_code(project_id: str, experiment_id: str) -> None:
5086 bucket = _get_bucket ()
5187 with tempfile .TemporaryDirectory () as tmpdirname :
5288 print ("creating archive of current code state" )
53- zip_fname = shutil .make_archive (
54- os .path .join (tmpdirname , "archive" ), "zip" , root_dir = "."
89+ zip_fname = os .path .join (tmpdirname , "archive.zip" )
90+ make_archive (
91+ zip_fname ,
92+ root_dir = "." ,
93+ exclude_prefixes = CODE_EXCLUDES ,
5594 )
5695 print ("uploading archive" )
5796 blob_path = CODE_BLOB_PATH .format (
@@ -73,7 +112,7 @@ def download_code(project_id: str, experiment_id: str) -> None:
73112 """
74113 bucket = _get_bucket ()
75114 with tempfile .TemporaryDirectory () as tmpdirname :
76- print ("downloading code acrhive " )
115+ print ("downloading code archive " )
77116 blob_path = CODE_BLOB_PATH .format (
78117 project_id = project_id , experiment_id = experiment_id
79118 )
0 commit comments