1818import pytest
1919import numpy
2020
21+ from sagemaker .chainer .defaults import CHAINER_VERSION
2122from sagemaker .chainer .estimator import Chainer
2223from sagemaker .chainer .model import ChainerModel
2324from sagemaker .utils import sagemaker_timestamp
2627
2728
@pytest.fixture(scope='module')
def chainer_training_job(sagemaker_session, chainer_full_version):
    """Run a single-instance MNIST training job once per test module.

    The job is launched on one ``ml.c4.xlarge`` instance using the Chainer
    version supplied by the ``chainer_full_version`` fixture. The value
    returned by ``_run_mnist_training_job`` is handed to dependent tests
    (``test_deploy_model`` passes it as ``TrainingJobName``).
    """
    return _run_mnist_training_job(sagemaker_session, "ml.c4.xlarge", 1, chainer_full_version)
3132
3233
def test_distributed_cpu_training(sagemaker_session, chainer_full_version):
    """Train MNIST on two ``ml.c4.xlarge`` instances with the requested Chainer version."""
    _run_mnist_training_job(sagemaker_session, "ml.c4.xlarge", 2, chainer_full_version)
3536
3637
def test_distributed_gpu_training(sagemaker_session, chainer_full_version):
    """Train MNIST on two ``ml.p2.xlarge`` instances with the requested Chainer version."""
    _run_mnist_training_job(sagemaker_session, "ml.p2.xlarge", 2, chainer_full_version)
3940
4041
41- def test_training_with_additional_hyperparameters (sagemaker_session ):
42+ def test_training_with_additional_hyperparameters (sagemaker_session , chainer_full_version ):
4243 with timeout (minutes = 15 ):
4344 script_path = os .path .join (DATA_DIR , 'chainer_mnist' , 'mnist.py' )
4445 data_path = os .path .join (DATA_DIR , 'chainer_mnist' )
4546
4647 chainer = Chainer (entry_point = script_path , role = 'SageMakerRole' ,
4748 train_instance_count = 1 , train_instance_type = "ml.c4.xlarge" ,
49+ framework_version = chainer_full_version ,
4850 sagemaker_session = sagemaker_session , hyperparameters = {'epochs' : 1 },
4951 use_mpi = True ,
5052 num_processes = 2 ,
@@ -75,8 +77,7 @@ def test_deploy_model(chainer_training_job, sagemaker_session):
7577 desc = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = chainer_training_job )
7678 model_data = desc ['ModelArtifacts' ]['S3ModelArtifacts' ]
7779 script_path = os .path .join (DATA_DIR , 'chainer_mnist' , 'mnist.py' )
78- model = ChainerModel (model_data , 'SageMakerRole' , entry_point = script_path ,
79- sagemaker_session = sagemaker_session )
80+ model = ChainerModel (model_data , 'SageMakerRole' , entry_point = script_path , sagemaker_session = sagemaker_session )
8081 predictor = model .deploy (1 , "ml.m4.xlarge" , endpoint_name = endpoint_name )
8182 _predict_and_assert (predictor )
8283
@@ -85,7 +86,8 @@ def test_async_fit(sagemaker_session):
8586 endpoint_name = 'test-chainer-attach-deploy-{}' .format (sagemaker_timestamp ())
8687
8788 with timeout (minutes = 5 ):
88- training_job_name = _run_mnist_training_job (sagemaker_session , "ml.c4.xlarge" , 1 , wait = False )
89+ training_job_name = _run_mnist_training_job (sagemaker_session , "ml.c4.xlarge" , 1 ,
90+ chainer_full_version = CHAINER_VERSION , wait = False )
8991
9092 print ("Waiting to re-attach to the training job: %s" % training_job_name )
9193 time .sleep (20 )
@@ -97,12 +99,13 @@ def test_async_fit(sagemaker_session):
9799 _predict_and_assert (predictor )
98100
99101
100- def test_failed_training_job (sagemaker_session ):
102+ def test_failed_training_job (sagemaker_session , chainer_full_version ):
101103 with timeout (minutes = 15 ):
102104 script_path = os .path .join (DATA_DIR , 'chainer_mnist' , 'failure_script.py' )
103105 data_path = os .path .join (DATA_DIR , 'chainer_mnist' )
104106
105107 chainer = Chainer (entry_point = script_path , role = 'SageMakerRole' ,
108+ framework_version = chainer_full_version ,
106109 train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
107110 sagemaker_session = sagemaker_session )
108111
@@ -113,7 +116,8 @@ def test_failed_training_job(sagemaker_session):
113116 chainer .fit (train_input )
114117
115118
116- def _run_mnist_training_job (sagemaker_session , instance_type , instance_count , wait = True ):
119+ def _run_mnist_training_job (sagemaker_session , instance_type , instance_count ,
120+ chainer_full_version , wait = True ):
117121 with timeout (minutes = 15 ):
118122
119123 script_path = os .path .join (DATA_DIR , 'chainer_mnist' , 'mnist.py' ) if instance_type == 1 else \
@@ -122,6 +126,7 @@ def _run_mnist_training_job(sagemaker_session, instance_type, instance_count, wa
122126 data_path = os .path .join (DATA_DIR , 'chainer_mnist' )
123127
124128 chainer = Chainer (entry_point = script_path , role = 'SageMakerRole' ,
129+ framework_version = chainer_full_version ,
125130 train_instance_count = instance_count , train_instance_type = instance_type ,
126131 sagemaker_session = sagemaker_session , hyperparameters = {'epochs' : 1 })
127132
0 commit comments