@@ -627,8 +627,12 @@ class Framework(EstimatorBase):
627627 such as training/deployment images and predictor instances.
628628 """
629629
630+ _DISTRIBUTION_SUPPORTED_FRAMEWORKS = ('mxnet' ,)
631+ LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
632+
630633 def __init__ (self , entry_point , source_dir = None , hyperparameters = None , enable_cloudwatch_metrics = False ,
631- container_log_level = logging .INFO , code_location = None , image_name = None , ** kwargs ):
634+ container_log_level = logging .INFO , code_location = None , image_name = None ,
635+ distributions = None , ** kwargs ):
632636 """Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
633637
634638 Args:
@@ -650,6 +654,8 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
650654 image_name (str): An alternate image name to use instead of the official Sagemaker image
651655 for the framework. This is useful to run one of the Sagemaker supported frameworks
652656 with an image containing custom dependencies.
657+ distributions (dict): A dictionary with information on how to run distributed training
658+ (default: None).
653659 **kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
654660 """
655661 super (Framework , self ).__init__ (** kwargs )
@@ -660,10 +666,27 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
660666 DeprecationWarning )
661667 self .enable_cloudwatch_metrics = False
662668 self .container_log_level = container_log_level
663- self ._hyperparameters = hyperparameters or {}
664669 self .code_location = code_location
665670 self .image_name = image_name
666671
672+ self ._hyperparameters = hyperparameters or {}
673+ self ._configure_distributions (distributions )
674+
675+ def _configure_distributions (self , distributions ):
676+ if distributions is None :
677+ return
678+
679+ if self .__framework_name__ not in self ._DISTRIBUTION_SUPPORTED_FRAMEWORKS :
680+ raise ValueError ('This framework does not support the distributions option.' )
681+
682+ if self .framework_version .split ('.' ) < self ._LOWEST_SCRIPT_MODE_VERSION :
683+ raise ValueError ('The distributions option is valid for only versions {} and higher'
684+ .format ('.' .join (self ._LOWEST_SCRIPT_MODE_VERSION )))
685+
686+ if 'parameter_server' in distributions :
687+ enabled = distributions ['parameter_server' ].get ('enabled' , False )
688+ self ._hyperparameters [self .LAUNCH_PS_ENV_NAME ] = enabled
689+
667690 def _prepare_for_training (self , job_name = None ):
668691 """Set hyperparameters needed for training. This method will also validate ``source_dir``.
669692
0 commit comments