1313from __future__ import absolute_import
1414
1515import inspect
16- from mock import Mock , patch
1716import os
18- from sagemaker .fw_utils import create_image_uri , framework_name_from_image , framework_version_from_tag , \
19- model_code_key_prefix
20- from sagemaker .fw_utils import tar_and_upload_dir , parse_s3_url , UploadedCode , validate_source_dir
17+
2118import pytest
19+ from mock import Mock , patch
2220
21+ from sagemaker .fw_utils import create_image_uri , framework_name_from_image , \
22+ framework_version_from_tag , \
23+ model_code_key_prefix
24+ from sagemaker .fw_utils import tar_and_upload_dir , parse_s3_url , UploadedCode , validate_source_dir
2325from sagemaker .utils import name_from_image
2426
2527DATA_DIR = 'data_dir'
3335@pytest .fixture ()
3436def sagemaker_session ():
3537 boto_mock = Mock (name = 'boto_session' , region_name = REGION )
36- ims = Mock (name = 'sagemaker_session' , boto_session = boto_mock )
37- ims .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
38- ims .expand_role = Mock (name = "expand_role" , return_value = ROLE )
39- ims .sagemaker_client .describe_training_job = Mock ( return_value = { 'ModelArtifacts' :
40- {'S3ModelArtifacts' : 's3://m/m.tar.gz' }})
41- return ims
38+ session_mock = Mock (name = 'sagemaker_session' , boto_session = boto_mock )
39+ session_mock .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
40+ session_mock .expand_role = Mock (name = "expand_role" , return_value = ROLE )
41+ session_mock .sagemaker_client .describe_training_job = \
42+ Mock ( return_value = { 'ModelArtifacts' : {'S3ModelArtifacts' : 's3://m/m.tar.gz' }})
43+ return session_mock
4244
4345
4446def test_create_image_uri_cpu ():
@@ -49,6 +51,16 @@ def test_create_image_uri_cpu():
4951 assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'
5052
5153
54+ def test_create_image_uri_no_python ():
55+ image_uri = create_image_uri ('mars-south-3' , 'mlfw' , 'ml.c4.large' , '1.0rc' , account = '23' )
56+ assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu'
57+
58+
59+ def test_create_image_uri_bad_python ():
60+ with pytest .raises (ValueError ):
61+ create_image_uri ('mars-south-3' , 'mlfw' , 'ml.c4.large' , '1.0rc' , 'py0' )
62+
63+
5264def test_create_image_uri_gpu ():
5365 image_uri = create_image_uri ('mars-south-3' , 'mlfw' , 'ml.p3.2xlarge' , '1.0rc' , 'py3' , '23' )
5466 assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
@@ -127,7 +139,8 @@ def test_tar_and_upload_dir_not_s3(sagemaker_session):
127139 script = os .path .basename (__file__ )
128140 directory = os .path .dirname (os .path .abspath (inspect .getfile (inspect .currentframe ())))
129141 result = tar_and_upload_dir (sagemaker_session , bucket , s3_key_prefix , script , directory )
130- assert result == UploadedCode ('s3://{}/{}/sourcedir.tar.gz' .format (bucket , s3_key_prefix ), script )
142+ assert result == UploadedCode ('s3://{}/{}/sourcedir.tar.gz' .format (bucket , s3_key_prefix ),
143+ script )
131144
132145
133146def test_framework_name_from_image_mxnet ():
@@ -149,21 +162,24 @@ def test_legacy_name_from_framework_image():
149162
150163
151164def test_legacy_name_from_wrong_framework ():
152- framework , py_ver , tag = framework_name_from_image ('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1' )
165+ framework , py_ver , tag = framework_name_from_image (
166+ '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1' )
153167 assert framework is None
154168 assert py_ver is None
155169 assert tag is None
156170
157171
158172def test_legacy_name_from_wrong_python ():
159- framework , py_ver , tag = framework_name_from_image ('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1' )
173+ framework , py_ver , tag = framework_name_from_image (
174+ '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1' )
160175 assert framework is None
161176 assert py_ver is None
162177 assert tag is None
163178
164179
165180def test_legacy_name_from_wrong_device ():
166- framework , py_ver , tag = framework_name_from_image ('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1' )
181+ framework , py_ver , tag = framework_name_from_image (
182+ '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1' )
167183 assert framework is None
168184 assert py_ver is None
169185 assert tag is None
0 commit comments