+ import copy
import os
import pytest
from promptflow.tools.exception import (
    OpenSourceLLMOnlineEndpointError,
    OpenSourceLLMKeyValidationError
)
from promptflow.tools.open_source_llm import OpenSourceLLM, API, ContentFormatterBase, LlamaContentFormatter
+ from typing import List, Dict


@pytest.fixture
@@ -18,19 +20,66 @@ def llama_chat_provider(llama_chat_custom_connection) -> OpenSourceLLM:
    return OpenSourceLLM(llama_chat_custom_connection)


+ @pytest.fixture
+ def endpoints_provider(open_source_llm_ws_service_connection) -> Dict[str, List[str]]:
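+     # Enumerate the online endpoints of the AzureML workspace named by the
+     # AZUREML_ARM_* environment variables, mapping endpoint name -> deployment names.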
+     if not open_source_llm_ws_service_connection:
+         pytest.skip("Service Credential not available")
+
+     from azure.ai.ml import MLClient
+     from azure.identity import DefaultAzureCredential
+     credential = DefaultAzureCredential(exclude_interactive_browser_credential=False)
+     ml_client = MLClient(
+         credential=credential,
+         subscription_id=os.getenv("AZUREML_ARM_SUBSCRIPTION"),
+         resource_group_name=os.getenv("AZUREML_ARM_RESOURCEGROUP"),
+         workspace_name=os.getenv("AZUREML_ARM_WORKSPACE_NAME"))
+
+     endpoints = {}
+     for ep in ml_client.online_endpoints.list():
+         endpoints[ep.name] = [d.name for d in ml_client.online_deployments.list(ep.name)]
+
+     return endpoints
+
+
+ @pytest.fixture
+ def chat_endpoints_provider(endpoints_provider: Dict[str, List[str]]) -> Dict[str, List[str]]:
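+     # Keep only the endpoints whose names contain one of the chat model markers below.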
+     chat_endpoint_names = ["gpt2", "llama-chat"]
+
+     chat_endpoints = {}
+     for key, value in endpoints_provider.items():
+         for ep_name in chat_endpoint_names:
+             if ep_name in key:
+                 chat_endpoints[key] = value
+
+     if len(chat_endpoints) <= 0:
+         pytest.skip("No Chat Endpoints Found")
+
+     return chat_endpoints
+
+
+ @pytest.fixture
+ def completion_endpoints_provider(endpoints_provider: Dict[str, List[str]]) -> Dict[str, List[str]]:
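+     # Keep only the endpoints whose names contain one of the completion model markers below.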
+     completion_endpoint_names = ["gpt2", "llama-comp"]
+
+     completion_endpoints = {}
+     for key, value in endpoints_provider.items():
+         for ep_name in completion_endpoint_names:
+             if ep_name in key:
+                 completion_endpoints[key] = value
+
+     if len(completion_endpoints) <= 0:
+         pytest.skip("No Completion Endpoints Found")
+
+     return completion_endpoints
+
+
@pytest.mark.usefixtures("use_secrets_config_file")
class TestOpenSourceLLM:
    completion_prompt = "In the context of Azure ML, what does the ML stand for?"
-
-     gpt2_chat_prompt = """system:
+     chat_prompt = """system:
You are a AI which helps Customers answer questions.

user:
- """ + completion_prompt
-
-     llama_chat_prompt = """system:
- You are a AI which helps Customers answer questions.
-
""" + completion_prompt

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
@@ -41,56 +90,54 @@ def test_open_source_llm_completion(self, gpt2_provider):
        assert len(response) > 25

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
-     def test_open_source_llm_completion_with_deploy(self, gpt2_custom_connection):
-         os_tool = OpenSourceLLM(
-             gpt2_custom_connection,
-             deployment_name="gpt2-9")
-         response = os_tool.call(
+     def test_open_source_llm_completion_with_deploy(self, gpt2_provider):
+         response = gpt2_provider.call(
            self.completion_prompt,
-             API.COMPLETION)
+             API.COMPLETION,
+             deployment_name="gpt2-9")
        assert len(response) > 25

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
    def test_open_source_llm_chat(self, gpt2_provider):
        response = gpt2_provider.call(
-             self.gpt2_chat_prompt,
+             self.chat_prompt,
            API.CHAT)
        assert len(response) > 25

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
-     def test_open_source_llm_chat_with_deploy(self, gpt2_custom_connection):
-         os_tool = OpenSourceLLM(
-             gpt2_custom_connection,
+     def test_open_source_llm_chat_with_deploy(self, gpt2_provider):
+         response = gpt2_provider.call(
+             self.chat_prompt,
+             API.CHAT,
            deployment_name="gpt2-9")
-         response = os_tool.call(
-             self.gpt2_chat_prompt,
-             API.CHAT)
        assert len(response) > 25

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
    def test_open_source_llm_chat_with_max_length(self, gpt2_provider):
        response = gpt2_provider.call(
-             self.gpt2_chat_prompt,
+             self.chat_prompt,
            API.CHAT,
            max_new_tokens=2)
        # GPT-2 doesn't take this parameter
        assert len(response) > 25

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
    def test_open_source_llm_con_url_chat(self, gpt2_custom_connection):
-         del gpt2_custom_connection.configs['endpoint_url']
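+         # Deep-copy the connection before mutating it so the shared fixture object is left intact.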
+         tmp = copy.deepcopy(gpt2_custom_connection)
+         del tmp.configs['endpoint_url']
        with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info:
-             os = OpenSourceLLM(gpt2_custom_connection)
+             os = OpenSourceLLM(tmp)
            os.call(self.chat_prompt, API.CHAT)
        assert exc_info.value.message == """Required key `endpoint_url` not found in given custom connection.
Required keys are: endpoint_url,model_family."""
        assert exc_info.value.error_codes == "UserError/ToolValidationError/OpenSourceLLMKeyValidationError".split("/")

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
    def test_open_source_llm_con_key_chat(self, gpt2_custom_connection):
-         del gpt2_custom_connection.secrets['endpoint_api_key']
+         tmp = copy.deepcopy(gpt2_custom_connection)
+         del tmp.secrets['endpoint_api_key']
        with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info:
-             os = OpenSourceLLM(gpt2_custom_connection)
+             os = OpenSourceLLM(tmp)
            os.call(self.chat_prompt, API.CHAT)
        assert exc_info.value.message == (
            "Required secret key `endpoint_api_key` "
@@ -100,9 +147,10 @@ def test_open_source_llm_con_key_chat(self, gpt2_custom_connection):

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
    def test_open_source_llm_con_model_chat(self, gpt2_custom_connection):
-         del gpt2_custom_connection.configs['model_family']
+         tmp = copy.deepcopy(gpt2_custom_connection)
+         del tmp.configs['model_family']
        with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info:
-             os = OpenSourceLLM(gpt2_custom_connection)
+             os = OpenSourceLLM(tmp)
            os.call(self.completion_prompt, API.COMPLETION)
        assert exc_info.value.message == """Required key `model_family` not found in given custom connection.
Required keys are: endpoint_url,model_family."""
@@ -114,7 +162,7 @@ def test_open_source_llm_escape_chat(self):
        assert out_of_danger == "The quick \\brown fox\\tjumped\\\\over \\the \\\\boy\\r\\n"

    def test_open_source_llm_llama_parse_chat_with_chat(self):
-         LlamaContentFormatter.parse_chat(self.llama_chat_prompt)
+         LlamaContentFormatter.parse_chat(self.chat_prompt)

    def test_open_source_llm_llama_parse_multi_turn(self):
        multi_turn_chat = """user:
@@ -163,8 +211,9 @@ def test_open_source_llm_llama_parse_chat_with_comp(self):

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
    def test_open_source_llm_llama_endpoint_miss(self, gpt2_custom_connection):
-         gpt2_custom_connection.configs['endpoint_url'] += 'completely/real/endpoint'
-         os = OpenSourceLLM(gpt2_custom_connection)
+         tmp = copy.deepcopy(gpt2_custom_connection)
+         tmp.configs['endpoint_url'] += 'completely/real/endpoint'
+         os = OpenSourceLLM(tmp)
        with pytest.raises(OpenSourceLLMOnlineEndpointError) as exc_info:
            os.call(
                self.completion_prompt,
@@ -175,30 +224,49 @@ def test_open_source_llm_llama_endpoint_miss(self, gpt2_custom_connection):
        assert exc_info.value.error_codes == "UserError/OpenSourceLLMOnlineEndpointError".split("/")

    @pytest.mark.skip_if_no_key("gpt2_custom_connection")
-     def test_open_source_llm_llama_deployment_miss(self, gpt2_custom_connection):
-         os = OpenSourceLLM(
-             gpt2_custom_connection,
-             deployment_name="completely/real/deployment-007")
+     def test_open_source_llm_llama_deployment_miss(self, gpt2_provider):
        with pytest.raises(OpenSourceLLMOnlineEndpointError) as exc_info:
-             os.call(self.completion_prompt, API.COMPLETION)
+             gpt2_provider.call(self.completion_prompt,
+                                API.COMPLETION,
+                                deployment_name="completely/real/deployment-007")
        assert exc_info.value.message == (
            "Exception hit calling Oneline Endpoint: "
+             "HTTPError: HTTP Error 404: Not Found")
        assert exc_info.value.error_codes == "UserError/OpenSourceLLMOnlineEndpointError".split("/")

-     @pytest.mark.skip
-     def test_open_source_llm_endpoint_name(self):
-         os.environ["AZUREML_ARM_SUBSCRIPTION"] = "<needs_value>"
-         os.environ["AZUREML_ARM_RESOURCEGROUP"] = "<needs_value>"
-         os.environ["AZUREML_ARM_WORKSPACE_NAME"] = "<needs_value>"
-
-         os_llm = OpenSourceLLM(endpoint_name="llama-temp-chat")
-         response = os_llm.call(self.llama_chat_prompt, API.CHAT)
-         assert len(response) > 25
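+     # The endpoint-name tests below rely on the workspace service connection and the
+     # endpoint/deployment fixtures defined at the top of this file.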
+     @pytest.mark.skip_if_no_key("open_source_llm_ws_service_connection")
+     def test_open_source_llm_chat_endpoint_name(self, chat_endpoints_provider):
+         for endpoint_name in chat_endpoints_provider:
+             os_llm = OpenSourceLLM(endpoint_name=endpoint_name)
+             response = os_llm.call(self.chat_prompt, API.CHAT)
+             assert len(response) > 25
+
+     @pytest.mark.skip_if_no_key("open_source_llm_ws_service_connection")
+     def test_open_source_llm_chat_endpoint_name_with_deployment(self, chat_endpoints_provider):
+         for endpoint_name in chat_endpoints_provider:
+             os_llm = OpenSourceLLM(endpoint_name=endpoint_name)
+             for deployment_name in chat_endpoints_provider[endpoint_name]:
+                 response = os_llm.call(self.chat_prompt, API.CHAT, deployment_name=deployment_name)
+                 assert len(response) > 25
+
+     @pytest.mark.skip_if_no_key("open_source_llm_ws_service_connection")
+     def test_open_source_llm_completion_endpoint_name(self, completion_endpoints_provider):
+         for endpoint_name in completion_endpoints_provider:
+             os_llm = OpenSourceLLM(endpoint_name=endpoint_name)
+             response = os_llm.call(self.completion_prompt, API.COMPLETION)
+             assert len(response) > 25
+
+     @pytest.mark.skip_if_no_key("open_source_llm_ws_service_connection")
+     def test_open_source_llm_completion_endpoint_name_with_deployment(self, completion_endpoints_provider):
+         for endpoint_name in completion_endpoints_provider:
+             os_llm = OpenSourceLLM(endpoint_name=endpoint_name)
+             for deployment_name in completion_endpoints_provider[endpoint_name]:
+                 response = os_llm.call(self.completion_prompt, API.COMPLETION, deployment_name=deployment_name)
+                 assert len(response) > 25

    @pytest.mark.skip_if_no_key("llama_chat_custom_connection")
    def test_open_source_llm_llama_chat(self, llama_chat_provider):
-         response = llama_chat_provider.call(self.llama_chat_prompt, API.CHAT)
+         response = llama_chat_provider.call(self.chat_prompt, API.CHAT)
        assert len(response) > 25

    @pytest.mark.skip_if_no_key("llama_chat_custom_connection")