From 311ad62dc48d798e2d9e74f8ca92b521fc84b2ea Mon Sep 17 00:00:00 2001 From: garywan Date: Mon, 13 Jan 2025 22:40:28 +0000 Subject: [PATCH] use jumpstart deployment config image as default optimization image --- .../serve/builder/jumpstart_builder.py | 113 ++++++++++- .../serve/test_serve_js_deep_unit_tests.py | 18 ++ .../serve/builder/test_js_builder.py | 180 +++++++++++++++++- .../serve/builder/test_model_builder.py | 8 +- 4 files changed, 314 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 37a77179cb..fa1bd3b516 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -17,7 +17,7 @@ import re from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type, Any, List, Dict, Optional +from typing import Type, Any, List, Dict, Optional, Tuple import logging from botocore.exceptions import ClientError @@ -82,6 +82,7 @@ ModelServer.DJL_SERVING, ModelServer.TGI, } +_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124" logger = logging.getLogger(__name__) @@ -829,7 +830,13 @@ def _optimize_for_jumpstart( self.pysdk_model._enable_network_isolation = False if quantization_config or sharding_config or is_compilation: - return create_optimization_job_args + # only apply default image for vLLM usecases. 
+ # vLLM does not support compilation for now so skip on compilation + return ( + create_optimization_job_args + if is_compilation + else self._set_optimization_image_default(create_optimization_job_args) + ) return None def _is_gated_model(self, model=None) -> bool: @@ -986,3 +993,105 @@ def _get_neuron_model_env_vars( ) return job_model.env return None + + def _set_optimization_image_default( + self, create_optimization_job_args: Dict[str, Any] + ) -> Dict[str, Any]: + """Defaults the optimization image to the JumpStart deployment config default + + Args: + create_optimization_job_args (Dict[str, Any]): create optimization job request + + Returns: + Dict[str, Any]: create optimization job request with image uri default + """ + default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"]) + + # find the latest vLLM image version + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): + if optimization_config.get("ModelQuantizationConfig"): + model_quantization_config = optimization_config.get("ModelQuantizationConfig") + provided_image = model_quantization_config.get("Image") + if provided_image and self._get_latest_lmi_version_from_list( + default_image, provided_image + ): + default_image = provided_image + if optimization_config.get("ModelShardingConfig"): + model_sharding_config = optimization_config.get("ModelShardingConfig") + provided_image = model_sharding_config.get("Image") + if provided_image and self._get_latest_lmi_version_from_list( + default_image, provided_image + ): + default_image = provided_image + + # default to latest vLLM version + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): + if optimization_config.get("ModelQuantizationConfig") is not None: + optimization_config.get("ModelQuantizationConfig")["Image"] = default_image + if optimization_config.get("ModelShardingConfig") is not None: + optimization_config.get("ModelShardingConfig")["Image"] = 
default_image + + logger.info("Defaulting to %s image for optimization job", default_image) + + return create_optimization_job_args + + def _get_default_vllm_image(self, image: str) -> str: + """Ensures the minimum working image version for vLLM enabled optimization techniques + + Args: + image (str): JumpStart provided default image + + Returns: + str: minimum working image version + """ + dlc_name, _ = image.split(":") + major_version_number, _, _ = self._parse_lmi_version(image) + + if int(major_version_number) < 13: + minimum_version_default = _JS_MINIMUM_VERSION_IMAGE.format(dlc_name) + return minimum_version_default + return image + + def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool: + """LMI version comparator + + Args: + version (str): current version + version_to_compare (str): version to compare to + + Returns: + bool: whether version_to_compare is larger than or equal to version + """ + parse_lmi_version = self._parse_lmi_version(version) + parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare) + + # Check major version + if parse_lmi_version_to_compare[0] > parse_lmi_version[0]: + return True + # Check minor version + if parse_lmi_version_to_compare[0] == parse_lmi_version[0]: + if parse_lmi_version_to_compare[1] > parse_lmi_version[1]: + return True + if parse_lmi_version_to_compare[1] == parse_lmi_version[1]: + # Check patch version + if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]: + return True + return False + return False + return False + + def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]: + """Parse out LMI version + + Args: + image (str): image to parse version out of + + Returns: + Tuple[int, int, int]: LMI version split into major, minor, patch + """ + _, dlc_tag = image.split(":") + _, lmi_version, _ = dlc_tag.split("-") + major_version, minor_version, patch_version = lmi_version.split(".") + major_version_number = major_version[3:] + + return 
(int(major_version_number), int(minor_version), int(patch_version)) diff --git a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py index 348c57745f..e13e672bec 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py +++ b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py @@ -32,6 +32,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() + schema_builder = SchemaBuilder("test", "test") model_builder = ModelBuilder( model="meta-textgeneration-llama-3-1-8b-instruct", @@ -50,6 +52,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e accept_eula=True, ) + assert not sagemaker_session.sagemaker_client.create_optimization_job.called + optimized_model.deploy() mock_create_model.assert_called_once_with( @@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_ accept_eula=True, ) + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelShardingConfig"]["Image"] + is not None + ) + optimized_model.deploy( resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8}) ) @@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are accept_eula=True, ) + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + is not None + ) + optimized_model.deploy() mock_create_model.assert_called_once_with( diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 
b6bd69e304..7a7bf4979c 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -75,7 +75,7 @@ "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" ) mock_djl_image_uri = ( - "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1" + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" ) mock_model_data = { @@ -1166,6 +1166,9 @@ def test_optimize_quantize_for_jumpstart( mock_pysdk_model.image_uri = mock_tgi_image_uri mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1201,6 +1204,10 @@ def test_optimize_quantize_for_jumpstart( ) self.assertIsNotNone(out_put) + self.assertEqual( + out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @@ -1287,6 +1294,9 @@ def test_optimize_quantize_and_compile_for_jumpstart( mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] mock_pysdk_model.config_name = "config_name" mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config} + mock_pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1319,6 +1329,8 @@ def test_optimize_quantize_and_compile_for_jumpstart( ) self.assertIsNotNone(out_put) + 
self.assertIsNone(out_put["OptimizationConfigs"][1]["ModelCompilationConfig"].get("Image")) + self.assertIsNone(out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"].get("Image")) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @@ -1640,6 +1652,9 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations( mock_lmi_js_model = MagicMock() mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } mock_lmi_js_model.env = { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -1718,6 +1733,9 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over mock_lmi_js_model = MagicMock() mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } mock_lmi_js_model.env = { "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -1763,3 +1781,163 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over "OPTION_TENSOR_PARALLEL_DEGREE": "8", "OPTION_QUANTIZE": "fp8", # should be added to the env } + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model", + return_value=False, + ) + def 
test_optimize_on_js_model_should_set_default_optimization_image( + self, + mock_is_fine_tuned, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + + mock_lmi_js_model = MagicMock() + mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-70b-instruct", + schema_builder=SchemaBuilder("test", "test"), + sagemaker_session=MagicMock(), + ) + model_builder.pysdk_model = mock_lmi_js_model + + # assert lmi version is upgraded to hardcoded default + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + ) + + # assert lmi version is left as is + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi21.0.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi21.0.0-cu124", + ) + + # assert lmi version is upgraded to the highest provided version + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelShardingConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + }, + { + "ModelQuantizationConfig": { 
+ "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124" + } + }, + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelShardingConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + self.assertEqual( + optimization_args["OptimizationConfigs"][1]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + + # assert lmi version is upgraded to the highest provided version and sets empty image config + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124" + } + }, + {"ModelShardingConfig": {}}, + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + self.assertEqual( + optimization_args["OptimizationConfigs"][1]["ModelShardingConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + + # assert lmi version is left as is on minor version bump + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.1.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.1.0-cu124", + ) + + # assert lmi version is left as is on patch version bump + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124", + ) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 1e20bf1cf3..107d65c301 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -3733,6 +3733,9 @@ def test_optimize_sharding_with_override_for_js( pysdk_model.env = {"key": "val"} pysdk_model._enable_network_isolation = True pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( @@ -3803,8 +3806,9 @@ def test_optimize_sharding_with_override_for_js( OptimizationConfigs=[ { "ModelShardingConfig": { - "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} - } + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + }, } ], OutputConfig={