Skip to content

Commit

Permalink
Merge pull request #5 from gjwoods/users/gewoods/open_source_llm/initial
Browse files Browse the repository at this point in the history
Users/gewoods/open source llm/initial
  • Loading branch information
gjwoods authored Sep 26, 2023
2 parents 165dc9e + f0a38ec commit 7331734
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/reference/tools-reference/open_source_llm_tool.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ The keys to set are:
| api | string | this will be Completion or Chat, depending on the scenario selected | Yes |
| connection | CustomConnection | the name of the connection which points to the Inferencing endpoint | Yes |
| model_kwargs | dictionary | generic model configuration values, for example temperature | Yes |
| deployment_name | string | the name of the deployment to target on the MIR endpoint. If no value is passed, the MIR load balancer settings will be used. | No |
| prompt | string | text prompt that the language model will complete | Yes |

## Outputs
Expand Down
9 changes: 6 additions & 3 deletions src/promptflow-tools/promptflow/tools/open_source_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,12 @@ def __init__(
endpoint_url: str,
endpoint_api_key: str,
content_formatter: ContentFormatterBase,
deployment_name: Optional[str] = None,
model_kwargs: Optional[Dict] = None,
):
self.endpoint_url = endpoint_url
self.endpoint_api_key = endpoint_api_key
self.deployment_name = deployment_name
self.content_formatter = content_formatter
self.model_kwargs = model_kwargs

Expand All @@ -318,9 +320,8 @@ def _call_endpoint(self, body: bytes) -> bytes:
headers = {"Content-Type": "application/json", "Authorization": ("Bearer " + self.endpoint_api_key)}

# If this is not set it'll use the default deployment on the endpoint.
if mir_deployment_name_config in self.model_kwargs:
headers[mir_deployment_name_config] = self.model_kwargs[mir_deployment_name_config]
del self.model_kwargs[mir_deployment_name_config]
if self.deployment_name is not None:
headers["azureml-model-deployment"] = self.deployment_name

req = urllib.request.Request(self.endpoint_url, body, headers)
response = urllib.request.urlopen(req, timeout=50)
Expand Down Expand Up @@ -388,6 +389,7 @@ def call(
self,
prompt: PromptTemplate,
api: API,
deployment_name: Optional[str] = None,
model_kwargs: Optional[Dict] = {},
**kwargs
) -> str:
Expand All @@ -403,6 +405,7 @@ def call(
endpoint_url=self.connection.configs['endpoint_url'],
endpoint_api_key=self.connection.secrets['endpoint_api_key'],
content_formatter=content_formatter,
deployment_name=deployment_name,
model_kwargs=model_kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ promptflow.tools.open_source_llm.OpenSourceLLM.call:
- completion
type:
- string
deployment_name:
default: null
type:
- string
model_kwargs:
default: "{}"
type:
Expand Down
29 changes: 28 additions & 1 deletion src/promptflow-tools/tests/test_open_source_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,29 @@ def test_open_source_llm_completion(self, gpt2_provider):
API.COMPLETION)
assert len(response) > 25

@pytest.mark.skip_if_no_key("gpt2_custom_connection")
def test_open_source_llm_completion_with_deploy(self, gpt2_provider):
    """Completion API call that targets an explicit MIR deployment by name."""
    result = gpt2_provider.call(
        self.completion_prompt,
        API.COMPLETION,
        deployment_name="gpt2-8",
    )
    # A genuine completion is expected to be non-trivial in length.
    assert len(result) > 25

@pytest.mark.skip_if_no_key("gpt2_custom_connection")
def test_open_source_llm_chat(self, gpt2_provider):
    """Chat API smoke test against the endpoint's default deployment."""
    output = gpt2_provider.call(self.chat_prompt, API.CHAT)
    # A genuine chat reply is expected to be non-trivial in length.
    assert len(output) > 25

@pytest.mark.skip_if_no_key("gpt2_custom_connection")
def test_open_source_llm_chat_with_deploy(self, gpt2_provider):
    """Chat API call that targets an explicit MIR deployment by name."""
    output = gpt2_provider.call(
        self.chat_prompt,
        API.CHAT,
        deployment_name="gpt2-8",
    )
    # A genuine chat reply is expected to be non-trivial in length.
    assert len(output) > 25

@pytest.mark.skip_if_no_key("gpt2_custom_connection")
def test_open_source_llm_con_url_chat(self, gpt2_custom_connection):
del gpt2_custom_connection.configs['endpoint_url']
Expand Down Expand Up @@ -121,7 +137,7 @@ def test_open_source_llm_llama_parse_chat_with_comp(self):
pass

@pytest.mark.skip_if_no_key("gpt2_custom_connection")
def test_open_source_llm_llama_endpoint_miss(self, gpt2_custom_connection):
    """Calling a non-existent endpoint URL must raise OpenSourceLLMOnlineEndpointError.

    Fix: the original ``try: ... except OpenSourceLLMOnlineEndpointError: pass``
    silently PASSED when no exception was raised at all; ``pytest.raises``
    fails the test unless the expected error actually occurs. The local
    variable is also renamed so it no longer shadows the stdlib ``os`` module.
    """
    gpt2_custom_connection.configs['endpoint_url'] += 'completely/real/endpoint'
    llm = OpenSourceLLM(gpt2_custom_connection)
    with pytest.raises(OpenSourceLLMOnlineEndpointError):
        llm.call(
            self.completion_prompt,
            API.COMPLETION)

@pytest.mark.skip_if_no_key("gpt2_custom_connection")
def test_open_source_llm_llama_deployment_miss(self, gpt2_custom_connection):
    """Targeting a non-existent deployment must raise OpenSourceLLMOnlineEndpointError.

    Fix: the original ``try: ... except OpenSourceLLMOnlineEndpointError: pass``
    silently PASSED when no exception was raised at all; ``pytest.raises``
    fails the test unless the expected error actually occurs. The local
    variable is also renamed so it no longer shadows the stdlib ``os`` module.
    """
    llm = OpenSourceLLM(gpt2_custom_connection)
    with pytest.raises(OpenSourceLLMOnlineEndpointError):
        llm.call(
            self.completion_prompt,
            API.COMPLETION,
            deployment_name="completely/real/deployment-007")

0 comments on commit 7331734

Please sign in to comment.