|
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+from collections import namedtuple
+
 import os
 import re
+import sagemaker.utils
 import shutil
 import tempfile
-from collections import namedtuple
 from six.moves.urllib.parse import urlparse
 
-import sagemaker.utils
-
 _TAR_SOURCE_FILENAME = 'source.tar.gz'
 
 UploadedCode = namedtuple('UserCode', ['s3_prefix', 'script_name'])
@@ -112,46 +112,57 @@ def validate_source_dir(script, directory): |
 
 
 def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, dependencies=None):
-    """Pack and upload source files to S3 only if directory is empty or local.
+    """Package source files and upload a compressed tar file to S3. The S3 location will be
+    ``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
+
+    If directory is an S3 URI, an UploadedCode object will be returned, but nothing will be
+    uploaded to S3 (this allows reuse of code already in S3).
+
+    If directory is None, the script will be added to the archive at ``./<basename of script>``.
 
-    Note:
-        If the directory points to S3 no action is taken.
+    If directory is not None, the (recursive) contents of the directory will be added to
+    the archive. directory is treated as the base path of the archive, and the script name is
+    assumed to be a filename or relative path inside the directory.
 
     Args:
         session (boto3.Session): Boto session used to access S3.
         bucket (str): S3 bucket to which the compressed file is uploaded.
         s3_key_prefix (str): Prefix for the S3 key.
-        script (str): Script filename.
-        directory (str or None): Directory containing the source file. If it starts with
-            "s3://", no action is taken.
-        dependencies (List[str]): A list of paths to directories (absolute or relative)
+        script (str): Script filename or path.
+        directory (str): Optional. Directory containing the source file. If it starts with "s3://",
+            no action is taken.
+        dependencies (List[str]): Optional. A list of paths to directories (absolute or relative)
             containing additional libraries that will be copied into
             /opt/ml/lib
 
     Returns:
-        sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.
+        sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
+            script name.
     """
-    dependencies = dependencies or []
-    key = '%s/sourcedir.tar.gz' % s3_key_prefix
-
     if directory and directory.lower().startswith('s3://'):
         return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
-    else:
-        tmp = tempfile.mkdtemp()
 
-        try:
-            source_files = _list_files_to_compress(script, directory) + dependencies
-            tar_file = sagemaker.utils.create_tar_file(source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME))
+    script_name = script if directory else os.path.basename(script)
+    dependencies = dependencies or []
+    key = '%s/sourcedir.tar.gz' % s3_key_prefix
+    tmp = tempfile.mkdtemp()
 
-            session.resource('s3').Object(bucket, key).upload_file(tar_file)
-        finally:
-            shutil.rmtree(tmp)
+    try:
+        source_files = _list_files_to_compress(script, directory) + dependencies
+        tar_file = sagemaker.utils.create_tar_file(source_files,
+                                                   os.path.join(tmp, _TAR_SOURCE_FILENAME))
 
-        script_name = script if directory else os.path.basename(script)
-        return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
+        session.resource('s3').Object(bucket, key).upload_file(tar_file)
+    finally:
+        shutil.rmtree(tmp)
+
+    return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name)
 
 
 def _list_files_to_compress(script, directory):
+    if directory is None:
+        return [script]
+
     basedir = directory if directory else os.path.dirname(script)
     return [os.path.join(basedir, name) for name in os.listdir(basedir)]
 
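For context, here is a minimal usage sketch of the updated `tar_and_upload_dir` (not part of this change). The import path follows the `sagemaker.fw_utils` module referenced in the docstring; the boto3 session is assumed to carry valid AWS credentials, and the bucket name, key prefix, and local paths are hypothetical placeholders.

```python
import boto3

from sagemaker.fw_utils import tar_and_upload_dir

session = boto3.Session()  # assumes AWS credentials are already configured

# directory given: its recursive contents are tarred and uploaded, and the
# script name is treated as a path relative to that directory.
uploaded = tar_and_upload_dir(session,
                              bucket='my-bucket',          # hypothetical bucket
                              s3_key_prefix='jobs/job-1',  # hypothetical prefix
                              script='train.py',
                              directory='src')
print(uploaded.s3_prefix)    # s3://my-bucket/jobs/job-1/sourcedir.tar.gz
print(uploaded.script_name)  # train.py

# directory is an S3 URI: nothing is uploaded; the existing code location is reused.
reused = tar_and_upload_dir(session,
                            bucket='my-bucket',
                            s3_key_prefix='jobs/job-1',
                            script='train.py',
                            directory='s3://my-bucket/existing-code')

# directory is None: only the script itself is archived, at ./train.py.
single = tar_and_upload_dir(session,
                            bucket='my-bucket',
                            s3_key_prefix='jobs/job-2',
                            script='code/train.py',
                            directory=None)
```

In each case the returned `UploadedCode` namedtuple carries the S3 prefix of the code and the script name to run from it.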
|
|