diff --git a/docs/reference/tools-reference/open_source_llm_tool.md b/docs/reference/tools-reference/open_source_llm_tool.md index ae0cca08823..f1ec9efbd4f 100644 --- a/docs/reference/tools-reference/open_source_llm_tool.md +++ b/docs/reference/tools-reference/open_source_llm_tool.md @@ -41,6 +41,7 @@ The keys to set are: | api | string | this will be Completion or Chat, depending on the scenario selected | Yes | | connection | CustomConnection | the name of the connection which points to the Inferencing endpoint | Yes | | model_kwargs | dictionary | generic model configuration values, for example temperature | Yes | +| deployment_name | string | the name of the deployment to target on the MIR endpoint. If no value is passed, the MIR load balancer settings will be used. | No | | prompt | string | text prompt that the language model will complete | Yes | ## Outputs diff --git a/src/promptflow-tools/promptflow/tools/open_source_llm.py b/src/promptflow-tools/promptflow/tools/open_source_llm.py index ae2d67a9f4e..deca3c321ea 100644 --- a/src/promptflow-tools/promptflow/tools/open_source_llm.py +++ b/src/promptflow-tools/promptflow/tools/open_source_llm.py @@ -292,10 +292,12 @@ def __init__( endpoint_url: str, endpoint_api_key: str, content_formatter: ContentFormatterBase, + deployment_name: Optional[str] = None, model_kwargs: Optional[Dict] = None, ): self.endpoint_url = endpoint_url self.endpoint_api_key = endpoint_api_key + self.deployment_name = deployment_name self.content_formatter = content_formatter self.model_kwargs = model_kwargs @@ -318,9 +320,8 @@ def _call_endpoint(self, body: bytes) -> bytes: headers = {"Content-Type": "application/json", "Authorization": ("Bearer " + self.endpoint_api_key)} # If this is not set it'll use the default deployment on the endpoint. - if mir_deployment_name_config in self.model_kwargs: - headers[mir_deployment_name_config] = self.model_kwargs[mir_deployment_name_config] - del self.model_kwargs[mir_deployment_name_config] + if self.deployment_name is not None: + headers["azureml-model-deployment"] = self.deployment_name req = urllib.request.Request(self.endpoint_url, body, headers) response = urllib.request.urlopen(req, timeout=50) @@ -388,6 +389,7 @@ def call( self, prompt: PromptTemplate, api: API, + deployment_name: Optional[str] = None, model_kwargs: Optional[Dict] = {}, **kwargs ) -> str: @@ -403,6 +405,7 @@ def call( endpoint_url=self.connection.configs['endpoint_url'], endpoint_api_key=self.connection.secrets['endpoint_api_key'], content_formatter=content_formatter, + deployment_name=deployment_name, model_kwargs=model_kwargs ) diff --git a/src/promptflow-tools/promptflow/tools/yamls/open_source_llm.yaml b/src/promptflow-tools/promptflow/tools/yamls/open_source_llm.yaml index 416bff5f5d1..f5bbd809164 100644 --- a/src/promptflow-tools/promptflow/tools/yamls/open_source_llm.yaml +++ b/src/promptflow-tools/promptflow/tools/yamls/open_source_llm.yaml @@ -15,6 +15,10 @@ promptflow.tools.open_source_llm.OpenSourceLLM.call: - completion type: - string + deployment_name: + default: null + type: + - string model_kwargs: default: "{}" type: diff --git a/src/promptflow-tools/tests/test_open_source_llm.py b/src/promptflow-tools/tests/test_open_source_llm.py index 928dd611f13..b5729c190a7 100644 --- a/src/promptflow-tools/tests/test_open_source_llm.py +++ b/src/promptflow-tools/tests/test_open_source_llm.py @@ -28,6 +28,14 @@ def test_open_source_llm_completion(self, gpt2_provider): API.COMPLETION) assert len(response) > 25 + @pytest.mark.skip_if_no_key("gpt2_custom_connection") + def test_open_source_llm_completion_with_deploy(self, gpt2_provider): + response = gpt2_provider.call( + self.completion_prompt, + API.COMPLETION, + deployment_name="gpt2-8") + assert len(response) > 25 + @pytest.mark.skip_if_no_key("gpt2_custom_connection") def test_open_source_llm_chat(self, gpt2_provider): response = gpt2_provider.call( @@ -35,6 +43,14 @@ def test_open_source_llm_chat(self, gpt2_provider): API.CHAT) assert len(response) > 25 + @pytest.mark.skip_if_no_key("gpt2_custom_connection") + def test_open_source_llm_chat_with_deploy(self, gpt2_provider): + response = gpt2_provider.call( + self.chat_prompt, + API.CHAT, + deployment_name="gpt2-8") + assert len(response) > 25 + @pytest.mark.skip_if_no_key("gpt2_custom_connection") def test_open_source_llm_con_url_chat(self, gpt2_custom_connection): del gpt2_custom_connection.configs['endpoint_url'] @@ -121,7 +137,7 @@ def test_open_source_llm_llama_parse_chat_with_comp(self): pass @pytest.mark.skip_if_no_key("gpt2_custom_connection") - def test_open_source_llm_llama_req_chat(self, gpt2_custom_connection): + def test_open_source_llm_llama_endpoint_miss(self, gpt2_custom_connection): gpt2_custom_connection.configs['endpoint_url'] += 'completely/real/endpoint' os = OpenSourceLLM(gpt2_custom_connection) try: @@ -130,3 +146,14 @@ def test_open_source_llm_llama_req_chat(self, gpt2_custom_connection): API.COMPLETION) except OpenSourceLLMOnlineEndpointError: pass + + @pytest.mark.skip_if_no_key("gpt2_custom_connection") + def test_open_source_llm_llama_deployment_miss(self, gpt2_custom_connection): + os = OpenSourceLLM(gpt2_custom_connection) + try: + os.call( + self.completion_prompt, + API.COMPLETION, + deployment_name="completely/real/deployment-007") + except OpenSourceLLMOnlineEndpointError: + pass