Skip to content

Commit eeb1df6

Browse files
gjwoodsGerard
and
Gerard
authored
[Open_Source_LLM] Add Endpoints Tests and switch Deployment setup (#842)
# Description 1. Put endpoint and deployment name into the "call" function to show the errors appropriately (and share the EndpointContainer) 2. Add tests for Endpoint & Deployment name functionality 3. Clean up old comments # All Promptflow Contribution checklist: - [x] **The pull request does not introduce breaking changes.** - [x] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [x] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [x] Title of the pull request is clear and informative. - [x] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [x] Pull request includes test coverage for the included changes. --------- Co-authored-by: Gerard <[email protected]>
1 parent 4abf690 commit eeb1df6

File tree

4 files changed

+152
-66
lines changed

4 files changed

+152
-66
lines changed

src/promptflow-tools/connections.json.example

+10
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@
5151
"endpoint_api_key"
5252
]
5353
},
54+
"open_source_llm_ws_service_connection": {
55+
"type": "CustomConnection",
56+
"value": {
57+
"service_credential": "service-credential"
58+
},
59+
"module": "promptflow.connections",
60+
"secret_keys": [
61+
"service_credential"
62+
]
63+
},
5464
"open_ai_connection": {
5565
"type": "OpenAIConnection",
5666
"value": {

src/promptflow-tools/promptflow/tools/open_source_llm.py

+16-21
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def get_deployment_from_endpoint(endpoint_name: str, deployment_name: str = None
140140
return (endpoint_uri, endpoint_key, model)
141141

142142

143-
def get_deployment_from_connection(connection: CustomConnection, deployment_name: str = None) -> Tuple[str, str, str]:
143+
def get_deployment_from_connection(connection: CustomConnection) -> Tuple[str, str, str]:
144144
conn_dict = dict(connection)
145145
for key in REQUIRED_CONFIG_KEYS:
146146
if key not in conn_dict:
@@ -352,17 +352,7 @@ def get_content_formatter(
352352

353353

354354
class AzureMLOnlineEndpoint:
355-
"""Azure ML Online Endpoint models.
356-
357-
Example:
358-
.. code-block:: python
359-
360-
azure_llm = AzureMLModel(
361-
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
362-
endpoint_api_key="my-api-key",
363-
content_formatter=content_formatter,
364-
)
365-
""" # noqa: E501
355+
"""Azure ML Online Endpoint models."""
366356

367357
endpoint_url: str = ""
368358
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
@@ -453,32 +443,37 @@ class OpenSourceLLM(ToolProvider):
453443

454444
def __init__(self,
455445
connection: CustomConnection = None,
456-
endpoint_name: str = None,
457-
deployment_name: str = None):
446+
endpoint_name: str = None):
458447
super().__init__()
459448

460-
self.deployment_name = deployment_name
461-
if endpoint_name is not None and endpoint_name != DEFAULT_ENDPOINT_NAME:
462-
(self.endpoint_uri,
463-
self.endpoint_key,
464-
self.model_family) = get_deployment_from_endpoint(endpoint_name, deployment_name)
465-
else:
449+
self.endpoint_key = None
450+
self.endpoint_name = endpoint_name
451+
452+
if endpoint_name is None or endpoint_name == DEFAULT_ENDPOINT_NAME:
466453
(self.endpoint_uri,
467454
self.endpoint_key,
468-
self.model_family) = get_deployment_from_connection(connection, deployment_name)
455+
self.model_family) = get_deployment_from_connection(connection)
469456

470457
@tool
471458
@handle_oneline_endpoint_error()
472459
def call(
473460
self,
474461
prompt: PromptTemplate,
475462
api: API,
463+
deployment_name: str = None,
476464
temperature: float = 1.0,
477465
max_new_tokens: int = 500,
478466
top_p: float = 1.0,
479467
model_kwargs: Optional[Dict] = {},
480468
**kwargs
481469
) -> str:
470+
self.deployment_name = deployment_name
471+
472+
if self.endpoint_key is None and self.endpoint_name is not None:
473+
(self.endpoint_uri,
474+
self.endpoint_key,
475+
self.model_family) = get_deployment_from_endpoint(self.endpoint_name, self.deployment_name)
476+
482477
prompt = render_jinja_template(prompt, trim_blocks=True, keep_trailing_newline=True, **kwargs)
483478

484479
model_kwargs["top_p"] = top_p

src/promptflow-tools/tests/conftest.py

+13
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ def llama_chat_custom_connection():
7777
return ConnectionManager().get("llama_chat_connection")
7878

7979

80+
@pytest.fixture
81+
def open_source_llm_ws_service_connection() -> bool:
82+
try:
83+
creds_custom_connection: CustomConnection = ConnectionManager().get("open_source_llm_ws_service_connection")
84+
subs = json.loads(creds_custom_connection.secrets['service_credential'])
85+
for key, value in subs.items():
86+
os.environ[key] = value
87+
return True
88+
except Exception as e:
89+
print(f'Something failed setting environment variables for service credentials. Error: {e}')
90+
return False
91+
92+
8093
@pytest.fixture(autouse=True)
8194
def skip_if_no_key(request, mocker):
8295
mocker.patch.dict(os.environ, {"PROMPTFLOW_CONNECTIONS": CONNECTION_FILE})

src/promptflow-tools/tests/test_open_source_llm.py

+113-45
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import os
23
import pytest
34
from promptflow.tools.exception import (
@@ -6,6 +7,7 @@
67
OpenSourceLLMKeyValidationError
78
)
89
from promptflow.tools.open_source_llm import OpenSourceLLM, API, ContentFormatterBase, LlamaContentFormatter
10+
from typing import List, Dict
911

1012

1113
@pytest.fixture
@@ -18,19 +20,66 @@ def llama_chat_provider(llama_chat_custom_connection) -> OpenSourceLLM:
1820
return OpenSourceLLM(llama_chat_custom_connection)
1921

2022

23+
@pytest.fixture
24+
def endpoints_provider(open_source_llm_ws_service_connection) -> Dict[str, List[str]]:
25+
if not open_source_llm_ws_service_connection:
26+
pytest.skip("Service Credential not available")
27+
28+
from azure.ai.ml import MLClient
29+
from azure.identity import DefaultAzureCredential
30+
credential = DefaultAzureCredential(exclude_interactive_browser_credential=False)
31+
ml_client = MLClient(
32+
credential=credential,
33+
subscription_id=os.getenv("AZUREML_ARM_SUBSCRIPTION"),
34+
resource_group_name=os.getenv("AZUREML_ARM_RESOURCEGROUP"),
35+
workspace_name=os.getenv("AZUREML_ARM_WORKSPACE_NAME"))
36+
37+
endpoints = {}
38+
for ep in ml_client.online_endpoints.list():
39+
endpoints[ep.name] = [d.name for d in ml_client.online_deployments.list(ep.name)]
40+
41+
return endpoints
42+
43+
44+
@pytest.fixture
45+
def chat_endpoints_provider(endpoints_provider: Dict[str, List[str]]) -> Dict[str, List[str]]:
46+
chat_endpoint_names = ["gpt2", "llama-chat"]
47+
48+
chat_endpoints = {}
49+
for key, value in endpoints_provider.items():
50+
for ep_name in chat_endpoint_names:
51+
if ep_name in key:
52+
chat_endpoints[key] = value
53+
54+
if len(chat_endpoints) <= 0:
55+
pytest.skip("No Chat Endpoints Found")
56+
57+
return chat_endpoints
58+
59+
60+
@pytest.fixture
61+
def completion_endpoints_provider(endpoints_provider: Dict[str, List[str]]) -> Dict[str, List[str]]:
62+
completion_endpoint_names = ["gpt2", "llama-comp"]
63+
64+
completion_endpoints = {}
65+
for key, value in endpoints_provider.items():
66+
for ep_name in completion_endpoint_names:
67+
if ep_name in key:
68+
completion_endpoints[key] = value
69+
70+
if len(completion_endpoints) <= 0:
71+
pytest.skip("No Completion Endpoints Found")
72+
73+
return completion_endpoints
74+
75+
2176
@pytest.mark.usefixtures("use_secrets_config_file")
2277
class TestOpenSourceLLM:
2378
completion_prompt = "In the context of Azure ML, what does the ML stand for?"
24-
25-
gpt2_chat_prompt = """system:
79+
chat_prompt = """system:
2680
You are a AI which helps Customers answer questions.
2781
2882
user:
29-
""" + completion_prompt
30-
31-
llama_chat_prompt = """system:
32-
You are a AI which helps Customers answer questions.
33-
3483
""" + completion_prompt
3584

3685
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
@@ -41,56 +90,54 @@ def test_open_source_llm_completion(self, gpt2_provider):
4190
assert len(response) > 25
4291

4392
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
44-
def test_open_source_llm_completion_with_deploy(self, gpt2_custom_connection):
45-
os_tool = OpenSourceLLM(
46-
gpt2_custom_connection,
47-
deployment_name="gpt2-9")
48-
response = os_tool.call(
93+
def test_open_source_llm_completion_with_deploy(self, gpt2_provider):
94+
response = gpt2_provider.call(
4995
self.completion_prompt,
50-
API.COMPLETION)
96+
API.COMPLETION,
97+
deployment_name="gpt2-9")
5198
assert len(response) > 25
5299

53100
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
54101
def test_open_source_llm_chat(self, gpt2_provider):
55102
response = gpt2_provider.call(
56-
self.gpt2_chat_prompt,
103+
self.chat_prompt,
57104
API.CHAT)
58105
assert len(response) > 25
59106

60107
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
61-
def test_open_source_llm_chat_with_deploy(self, gpt2_custom_connection):
62-
os_tool = OpenSourceLLM(
63-
gpt2_custom_connection,
108+
def test_open_source_llm_chat_with_deploy(self, gpt2_provider):
109+
response = gpt2_provider.call(
110+
self.chat_prompt,
111+
API.CHAT,
64112
deployment_name="gpt2-9")
65-
response = os_tool.call(
66-
self.gpt2_chat_prompt,
67-
API.CHAT)
68113
assert len(response) > 25
69114

70115
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
71116
def test_open_source_llm_chat_with_max_length(self, gpt2_provider):
72117
response = gpt2_provider.call(
73-
self.gpt2_chat_prompt,
118+
self.chat_prompt,
74119
API.CHAT,
75120
max_new_tokens=2)
76121
# GPT-2 doesn't take this parameter
77122
assert len(response) > 25
78123

79124
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
80125
def test_open_source_llm_con_url_chat(self, gpt2_custom_connection):
81-
del gpt2_custom_connection.configs['endpoint_url']
126+
tmp = copy.deepcopy(gpt2_custom_connection)
127+
del tmp.configs['endpoint_url']
82128
with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info:
83-
os = OpenSourceLLM(gpt2_custom_connection)
129+
os = OpenSourceLLM(tmp)
84130
os.call(self.chat_prompt, API.CHAT)
85131
assert exc_info.value.message == """Required key `endpoint_url` not found in given custom connection.
86132
Required keys are: endpoint_url,model_family."""
87133
assert exc_info.value.error_codes == "UserError/ToolValidationError/OpenSourceLLMKeyValidationError".split("/")
88134

89135
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
90136
def test_open_source_llm_con_key_chat(self, gpt2_custom_connection):
91-
del gpt2_custom_connection.secrets['endpoint_api_key']
137+
tmp = copy.deepcopy(gpt2_custom_connection)
138+
del tmp.secrets['endpoint_api_key']
92139
with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info:
93-
os = OpenSourceLLM(gpt2_custom_connection)
140+
os = OpenSourceLLM(tmp)
94141
os.call(self.chat_prompt, API.CHAT)
95142
assert exc_info.value.message == (
96143
"Required secret key `endpoint_api_key` "
@@ -100,9 +147,10 @@ def test_open_source_llm_con_key_chat(self, gpt2_custom_connection):
100147

101148
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
102149
def test_open_source_llm_con_model_chat(self, gpt2_custom_connection):
103-
del gpt2_custom_connection.configs['model_family']
150+
tmp = copy.deepcopy(gpt2_custom_connection)
151+
del tmp.configs['model_family']
104152
with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info:
105-
os = OpenSourceLLM(gpt2_custom_connection)
153+
os = OpenSourceLLM(tmp)
106154
os.call(self.completion_prompt, API.COMPLETION)
107155
assert exc_info.value.message == """Required key `model_family` not found in given custom connection.
108156
Required keys are: endpoint_url,model_family."""
@@ -114,7 +162,7 @@ def test_open_source_llm_escape_chat(self):
114162
assert out_of_danger == "The quick \\brown fox\\tjumped\\\\over \\the \\\\boy\\r\\n"
115163

116164
def test_open_source_llm_llama_parse_chat_with_chat(self):
117-
LlamaContentFormatter.parse_chat(self.llama_chat_prompt)
165+
LlamaContentFormatter.parse_chat(self.chat_prompt)
118166

119167
def test_open_source_llm_llama_parse_multi_turn(self):
120168
multi_turn_chat = """user:
@@ -163,8 +211,9 @@ def test_open_source_llm_llama_parse_chat_with_comp(self):
163211

164212
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
165213
def test_open_source_llm_llama_endpoint_miss(self, gpt2_custom_connection):
166-
gpt2_custom_connection.configs['endpoint_url'] += 'completely/real/endpoint'
167-
os = OpenSourceLLM(gpt2_custom_connection)
214+
tmp = copy.deepcopy(gpt2_custom_connection)
215+
tmp.configs['endpoint_url'] += 'completely/real/endpoint'
216+
os = OpenSourceLLM(tmp)
168217
with pytest.raises(OpenSourceLLMOnlineEndpointError) as exc_info:
169218
os.call(
170219
self.completion_prompt,
@@ -175,30 +224,49 @@ def test_open_source_llm_llama_endpoint_miss(self, gpt2_custom_connection):
175224
assert exc_info.value.error_codes == "UserError/OpenSourceLLMOnlineEndpointError".split("/")
176225

177226
@pytest.mark.skip_if_no_key("gpt2_custom_connection")
178-
def test_open_source_llm_llama_deployment_miss(self, gpt2_custom_connection):
179-
os = OpenSourceLLM(
180-
gpt2_custom_connection,
181-
deployment_name="completely/real/deployment-007")
227+
def test_open_source_llm_llama_deployment_miss(self, gpt2_provider):
182228
with pytest.raises(OpenSourceLLMOnlineEndpointError) as exc_info:
183-
os.call(self.completion_prompt, API.COMPLETION)
229+
gpt2_provider.call(self.completion_prompt,
230+
API.COMPLETION,
231+
deployment_name="completely/real/deployment-007")
184232
assert exc_info.value.message == (
185233
"Exception hit calling Oneline Endpoint: "
186234
+ "HTTPError: HTTP Error 404: Not Found")
187235
assert exc_info.value.error_codes == "UserError/OpenSourceLLMOnlineEndpointError".split("/")
188236

189-
@pytest.mark.skip
190-
def test_open_source_llm_endpoint_name(self):
191-
os.environ["AZUREML_ARM_SUBSCRIPTION"] = "<needs_value>"
192-
os.environ["AZUREML_ARM_RESOURCEGROUP"] = "<needs_value>"
193-
os.environ["AZUREML_ARM_WORKSPACE_NAME"] = "<needs_value>"
194-
195-
os_llm = OpenSourceLLM(endpoint_name="llama-temp-chat")
196-
response = os_llm.call(self.llama_chat_prompt, API.CHAT)
197-
assert len(response) > 25
237+
@pytest.mark.skip_if_no_key("open_source_llm_ws_service_connection")
238+
def test_open_source_llm_chat_endpoint_name(self, chat_endpoints_provider):
239+
for endpoint_name in chat_endpoints_provider:
240+
os_llm = OpenSourceLLM(endpoint_name=endpoint_name)
241+
response = os_llm.call(self.chat_prompt, API.CHAT)
242+
assert len(response) > 25
243+
244+
@pytest.mark.skip_if_no_key("open_source_llm_ws_service_connection")
245+
def test_open_source_llm_chat_endpoint_name_with_deployment(self, chat_endpoints_provider):
246+
for endpoint_name in chat_endpoints_provider:
247+
os_llm = OpenSourceLLM(endpoint_name=endpoint_name)
248+
for deployment_name in chat_endpoints_provider[endpoint_name]:
249+
response = os_llm.call(self.chat_prompt, API.CHAT, deployment_name=deployment_name)
250+
assert len(response) > 25
251+
252+
@pytest.mark.skip_if_no_key("open_source_llm_ws_service_connection")
253+
def test_open_source_llm_completion_endpoint_name(self, completion_endpoints_provider):
254+
for endpoint_name in completion_endpoints_provider:
255+
os_llm = OpenSourceLLM(endpoint_name=endpoint_name)
256+
response = os_llm.call(self.completion_prompt, API.COMPLETION)
257+
assert len(response) > 25
258+
259+
@pytest.mark.skip_if_no_key("open_source_llm_ws_service_connection")
260+
def test_open_source_llm_completion_endpoint_name_with_deployment(self, completion_endpoints_provider):
261+
for endpoint_name in completion_endpoints_provider:
262+
os_llm = OpenSourceLLM(endpoint_name=endpoint_name)
263+
for deployment_name in completion_endpoints_provider[endpoint_name]:
264+
response = os_llm.call(self.completion_prompt, API.COMPLETION, deployment_name=deployment_name)
265+
assert len(response) > 25
198266

199267
@pytest.mark.skip_if_no_key("llama_chat_custom_connection")
200268
def test_open_source_llm_llama_chat(self, llama_chat_provider):
201-
response = llama_chat_provider.call(self.llama_chat_prompt, API.CHAT)
269+
response = llama_chat_provider.call(self.chat_prompt, API.CHAT)
202270
assert len(response) > 25
203271

204272
@pytest.mark.skip_if_no_key("llama_chat_custom_connection")

0 commit comments

Comments
 (0)