diff --git a/docs/media/reference/tools-reference/open_model_llm_on_vscode_promptflow.png b/docs/media/reference/tools-reference/open_model_llm_on_vscode_promptflow.png new file mode 100644 index 00000000000..318c5c16a67 Binary files /dev/null and b/docs/media/reference/tools-reference/open_model_llm_on_vscode_promptflow.png differ diff --git a/docs/media/reference/tools-reference/open_source_llm_on_vscode_promptflow.png b/docs/media/reference/tools-reference/open_source_llm_on_vscode_promptflow.png deleted file mode 100644 index 9e79e237917..00000000000 Binary files a/docs/media/reference/tools-reference/open_source_llm_on_vscode_promptflow.png and /dev/null differ diff --git a/docs/reference/index.md b/docs/reference/index.md index 0ceb9cae01d..598dba9c944 100644 --- a/docs/reference/index.md +++ b/docs/reference/index.md @@ -37,7 +37,7 @@ tools-reference/serp-api-tool tools-reference/faiss_index_lookup_tool tools-reference/vector_db_lookup_tool tools-reference/embedding_tool -tools-reference/open_source_llm_tool +tools-reference/open_model_llm_tool tools-reference/openai-gpt-4v-tool tools-reference/contentsafety_text_tool ``` diff --git a/docs/reference/tools-reference/open_source_llm_tool.md b/docs/reference/tools-reference/open_model_llm_tool.md similarity index 70% rename from docs/reference/tools-reference/open_source_llm_tool.md rename to docs/reference/tools-reference/open_model_llm_tool.md index 6825b5fa242..1c4f4894479 100644 --- a/docs/reference/tools-reference/open_source_llm_tool.md +++ b/docs/reference/tools-reference/open_model_llm_tool.md @@ -1,23 +1,23 @@ -# Open Source LLM +# Open Model LLM ## Introduction -The Open Source LLM tool enables the utilization of a variety of Open Source and Foundational Models, such as [Falcon](https://ml.azure.com/models/tiiuae-falcon-7b/version/4/catalog/registry/azureml) and [Llama 2](https://ml.azure.com/models/Llama-2-7b-chat/version/14/catalog/registry/azureml-meta), for natural language processing in Azure ML Prompt Flow. +The Open Model LLM tool enables the use of a variety of open and foundational models, such as [Falcon](https://ml.azure.com/models/tiiuae-falcon-7b/version/4/catalog/registry/azureml) and [Llama 2](https://ml.azure.com/models/Llama-2-7b-chat/version/14/catalog/registry/azureml-meta), for natural language processing in Azure ML Prompt Flow. Here's how it looks in action on the Visual Studio Code prompt flow extension. In this example, the tool is being used to call a LlaMa-2 chat endpoint and ask "What is CI?". -![Screenshot of the Open Source LLM On VScode Prompt Flow extension](../../media/reference/tools-reference/open_source_llm_on_vscode_promptflow.png) +![Screenshot of the Open Model LLM On VScode Prompt Flow extension](../../media/reference/tools-reference/open_model_llm_on_vscode_promptflow.png) This prompt flow tool supports two different LLM API types: - **Chat**: Shown in the example above. The chat API type facilitates interactive conversations with text-based inputs and responses. - **Completion**: The Completion API type is used to generate single response text completions based on provided prompt input. -## Quick Overview: How do I use Open Source LLM Tool? +## Quick Overview: How do I use Open Model LLM Tool? 1. Choose a Model from the AzureML Model Catalog and get it deployed. 2. Connect to the model deployment. -3. Configure the open source llm tool settings. +3. Configure the Open Model LLM tool settings. 4. Prepare the Prompt with [guidance](./prompt-tool.md#how-to-write-prompt). 
5. Run the flow. @@ -35,15 +35,15 @@ In order for prompt flow to use your deployed model, you will need to connect to ### 1. Endpoint Connections -Once associated to a AzureML or Azure AI Studio workspace, the Open Source LLM tool can use the endpoints on that workspace. +Once associated with an AzureML or Azure AI Studio workspace, the Open Model LLM tool can use the endpoints on that workspace. 1. **Using AzureML or Azure AI Studio workspaces**: If you are using prompt flow in one of the web page based browser workspaces, the online endpoints available on that workspace will automatically show up. -2. **Using VScode or Code First**: If you are using prompt flow in VScode or one of the Code First offerings, you will need to connect to the workspace. The Open Source LLM tool uses the azure.identity DefaultAzureCredential client for authorization. One way is through [setting environment credential values](https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.environmentcredential?view=azure-python). +2. **Using VScode or Code First**: If you are using prompt flow in VScode or one of the Code First offerings, you will need to connect to the workspace. The Open Model LLM tool uses the azure.identity DefaultAzureCredential client for authorization. One way is through [setting environment credential values](https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.environmentcredential?view=azure-python). ### 2. Custom Connections -The Open Source LLM tool uses the CustomConnection. Prompt flow supports two types of connections: +The Open Model LLM tool uses the CustomConnection. Prompt flow supports two types of connections: 1. **Workspace Connections** - These are connections which are stored as secrets on an Azure Machine Learning workspace. While these can be used in many places, they are commonly created and maintained in the Studio UI. @@ -64,7 +64,7 @@ The required keys to set are: ## Running the Tool: Inputs -The Open Source LLM tool has a number of parameters, some of which are required. Please see the below table for details, you can match these to the screen shot above for visual clarity. +The Open Model LLM tool has a number of parameters, some of which are required. Please see the table below for details; you can match these to the screenshot above for visual clarity. | Name | Type | Description | Required | |------|------|-------------|----------| @@ -83,3 +83,7 @@ The Open Source LLM tool has a number of parameters, some of which are required. |------------|-------------|------------------------------------------| | Completion | string | The text of one predicted completion | | Chat | string | The text of one response in the conversation | + +## Deploying to an Online Endpoint + +When deploying a flow containing the Open Model LLM tool to an online endpoint, there is an additional step to set up permissions. During deployment through the web pages, there is a choice between System-assigned and User-assigned Identity types. Either way, use the Azure Portal (or similar functionality) to add the "Reader" job function role to the identity on the Azure Machine Learning workspace or AI Studio project that hosts the endpoint. The prompt flow deployment may need to be refreshed. 
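As a supplement to the Custom Connections section of the renamed doc above, here is a minimal sketch of creating a local custom connection carrying the keys this tool validates (per the REQUIRED_CONFIG_KEYS/REQUIRED_SECRET_KEYS checks and the test messages later in this diff: `endpoint_url` and `model_family` as configs, `endpoint_api_key` as a secret). It assumes the promptflow Python SDK is installed; the connection name and placeholder values are illustrative, not part of this PR.

```python
# Hypothetical sketch: create a local CustomConnection usable by the Open Model LLM tool.
# Key names mirror the config/secret keys validated in open_model_llm.py; values are placeholders.
from promptflow import PFClient
from promptflow.entities import CustomConnection

pf = PFClient()

connection = CustomConnection(
    name="llama_chat_connection",  # illustrative name
    configs={
        "endpoint_url": "https://<your-endpoint>.<region>.inference.ml.azure.com/score",
        "model_family": "LLAMA",  # a ModelFamily member name: LLAMA, DOLLY, GPT2, or FALCON
    },
    secrets={
        "endpoint_api_key": "<endpoint-api-key>",  # stored as a secret value
    },
)

pf.connections.create_or_update(connection)
```

If this matches the SDK version in use, the tool's endpoint_name input can then reference the connection as "connection/llama_chat_connection", mirroring the `connection/...` fixtures in the tests further down in this diff.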
diff --git a/src/promptflow-tools/promptflow/tools/exception.py b/src/promptflow-tools/promptflow/tools/exception.py index b9850072105..4bcdab2ee41 100644 --- a/src/promptflow-tools/promptflow/tools/exception.py +++ b/src/promptflow-tools/promptflow/tools/exception.py @@ -144,21 +144,21 @@ def __init__(self, **kwargs): super().__init__(**kwargs, target=ErrorTarget.TOOL) -class OpenSourceLLMOnlineEndpointError(UserErrorException): +class OpenModelLLMOnlineEndpointError(UserErrorException): """Base exception raised when the call to an online endpoint failed.""" def __init__(self, **kwargs): super().__init__(**kwargs, target=ErrorTarget.TOOL) -class OpenSourceLLMUserError(UserErrorException): - """Base exception raised when the call to Open Source LLM failed with a user error.""" +class OpenModelLLMUserError(UserErrorException): + """Base exception raised when the call to Open Model LLM failed with a user error.""" def __init__(self, **kwargs): super().__init__(**kwargs, target=ErrorTarget.TOOL) -class OpenSourceLLMKeyValidationError(ToolValidationError): +class OpenModelLLMKeyValidationError(ToolValidationError): """Base exception raised when failed to validate functions when call chat api.""" def __init__(self, **kwargs): diff --git a/src/promptflow-tools/promptflow/tools/open_source_llm.py b/src/promptflow-tools/promptflow/tools/open_model_llm.py similarity index 92% rename from src/promptflow-tools/promptflow/tools/open_source_llm.py rename to src/promptflow-tools/promptflow/tools/open_model_llm.py index 9515a510400..84c5c28a0e0 100644 --- a/src/promptflow-tools/promptflow/tools/open_source_llm.py +++ b/src/promptflow-tools/promptflow/tools/open_model_llm.py @@ -18,9 +18,9 @@ from promptflow.contracts.types import PromptTemplate from promptflow.tools.common import render_jinja_template, validate_role from promptflow.tools.exception import ( - OpenSourceLLMOnlineEndpointError, - OpenSourceLLMUserError, - OpenSourceLLMKeyValidationError, + OpenModelLLMOnlineEndpointError, + OpenModelLLMUserError, + OpenModelLLMKeyValidationError, ChatAPIInvalidRole ) @@ -44,11 +44,11 @@ def wrapper(*args, **kwargs): for i in range(max_retries): try: return func(*args, **kwargs) - except OpenSourceLLMOnlineEndpointError as e: + except OpenModelLLMOnlineEndpointError as e: if i == max_retries - 1: error_message = f"Exception hit calling Online Endpoint: {type(e).__name__}: {str(e)}" print(error_message, file=sys.stderr) - raise OpenSourceLLMOnlineEndpointError(message=error_message) + raise OpenModelLLMOnlineEndpointError(message=error_message) delay *= exponential_base time.sleep(delay) @@ -93,11 +93,11 @@ class Endpoint: def __init__(self, endpoint_name: str, endpoint_url: str, - endpoint_key: str): + endpoint_api_key: str): self.deployments: List[Deployment] = [] self.default_deployment: Deployment = None self.endpoint_url = endpoint_url - self.endpoint_key = endpoint_key + self.endpoint_api_key = endpoint_api_key self.endpoint_name = endpoint_name @@ -221,13 +221,13 @@ def get_serverless_endpoint_key(self, serverless_endpoint_name) endpoint_url = try_get_from_dict(endpoint, ['properties', 'inferenceEndpoint', 'uri']) model_family = self._validate_model_family(endpoint) - endpoint_key = self._list_endpoint_key(token, - subscription_id, - resource_group, - workspace_name, - serverless_endpoint_name)['primaryKey'] + endpoint_api_key = self._list_endpoint_key(token, + subscription_id, + resource_group, + workspace_name, + serverless_endpoint_name)['primaryKey'] return (endpoint_url, - endpoint_key, + 
endpoint_api_key, model_family) @@ -334,7 +334,7 @@ def get_endpoint_from_custom_connection(self, connection: CustomConnection) -> T for key in REQUIRED_CONFIG_KEYS: if key not in conn_dict: accepted_keys = ",".join([key for key in REQUIRED_CONFIG_KEYS]) - raise OpenSourceLLMKeyValidationError( + raise OpenModelLLMKeyValidationError( message=f"""Required key `{key}` not found in given custom connection. Required keys are: {accepted_keys}.""" ) @@ -342,7 +342,7 @@ def get_endpoint_from_custom_connection(self, connection: CustomConnection) -> T for key in REQUIRED_SECRET_KEYS: if key not in conn_dict: accepted_keys = ",".join([key for key in REQUIRED_SECRET_KEYS]) - raise OpenSourceLLMKeyValidationError( + raise OpenModelLLMKeyValidationError( message=f"""Required secret key `{key}` not found in given custom connection. Required keys are: {accepted_keys}.""" ) @@ -389,7 +389,7 @@ def get_ml_client(self, message = "Unable to connect to AzureML. Please ensure the following environment variables are set: " message += ",".join(ENDPOINT_REQUIRED_ENV_VARS) message += "\nException: " + str(e) - raise OpenSourceLLMOnlineEndpointError(message=message) + raise OpenModelLLMOnlineEndpointError(message=message) def get_endpoints_and_deployments(self, credential, @@ -404,7 +404,7 @@ def get_endpoints_and_deployments(self, endpoint = Endpoint( endpoint_name=ep.name, endpoint_url=ep.scoring_uri, - endpoint_key=ml_client.online_endpoints.get_keys(ep.name).primary_key) + endpoint_api_key=ml_client.online_endpoints.get_keys(ep.name).primary_key) ordered_deployment_names = sorted(ep.traffic, key=lambda item: item[1]) deployments = ml_client.online_deployments.list(ep.name) @@ -653,7 +653,7 @@ def validate_model_family(model_family: str): return ModelFamily[model_family] except KeyError: accepted_models = ",".join([model.name for model in ModelFamily]) - raise OpenSourceLLMKeyValidationError( + raise OpenModelLLMKeyValidationError( message=f"""Given model_family '{model_family}' not recognized. Supported models are: {accepted_models}.""" ) @@ -718,11 +718,11 @@ def parse_chat(chat_str: str) -> List[Dict[str, str]]: try: validate_role(role, VALID_LLAMA_ROLES) except ChatAPIInvalidRole as e: - raise OpenSourceLLMUserError(message=e.message) + raise OpenModelLLMUserError(message=e.message) if len(chunks) <= index + 1: message = "Unexpected chat format. Please ensure the query matches the chat format of the model used." - raise OpenSourceLLMUserError(message=message) + raise OpenModelLLMUserError(message=message) chat_list.append({ "role": role, @@ -812,7 +812,7 @@ def format_response_payload(self, output: bytes) -> str: else: error_message = f"Unexpected response format. Response: {response_json}" print(error_message, file=sys.stderr) - raise OpenSourceLLMOnlineEndpointError(message=error_message) + raise OpenModelLLMOnlineEndpointError(message=error_message) class ServerlessLlamaContentFormatter(ContentFormatterBase): @@ -857,7 +857,7 @@ def format_response_payload(self, output: bytes) -> str: else: error_message = f"Unexpected response format. 
Response: {response_json}" print(error_message, file=sys.stderr) - raise OpenSourceLLMOnlineEndpointError(message=error_message) + raise OpenModelLLMOnlineEndpointError(message=error_message) class ContentFormatterFactory: @@ -916,7 +916,7 @@ def _call_endpoint(self, request_body: str) -> str: headers = { "Content-Type": "application/json", "Authorization": ("Bearer " + self.endpoint_api_key), - "x-ms-user-agent": "PromptFlow/OpenSourceLLM/" + self.model_family + "x-ms-user-agent": "PromptFlow/OpenModelLLM/" + self.model_family } # If this is not set it'll use the default deployment on the endpoint. @@ -928,7 +928,7 @@ def _call_endpoint(self, request_body: str) -> str: error_message = f"""Request failure while calling Online Endpoint Status:{result.status_code} Error:{result.text}""" print(error_message, file=sys.stderr) - raise OpenSourceLLMOnlineEndpointError(message=error_message) + raise OpenModelLLMOnlineEndpointError(message=error_message) return result.text @@ -953,7 +953,7 @@ def __call__( return response -class OpenSourceLLM(ToolProvider): +class OpenModelLLM(ToolProvider): def __init__(self): super().__init__() @@ -976,18 +976,18 @@ def get_deployment_from_endpoint(self, if ep.endpoint_name == endpoint_name: if deployment_name is None: return (ep.endpoint_url, - ep.endpoint_key, + ep.endpoint_api_key, ep.default_deployment.model_family) for d in ep.deployments: if d.deployment_name == deployment_name: return (ep.endpoint_url, - ep.endpoint_key, + ep.endpoint_api_key, d.model_family) message = f"""Invalid endpoint and deployment values. Please ensure endpoint name and deployment names are correct, and the deployment was successfull. Could not find endpoint: {endpoint_name} and deployment: {deployment_name}""" - raise OpenSourceLLMUserError(message=message) + raise OpenModelLLMUserError(message=message) def sanitize_endpoint_url(self, endpoint_url: str, @@ -1011,16 +1011,16 @@ def get_endpoint_details(self, deployment_name: str = None, **kwargs) -> Tuple[str, str, str]: if self.endpoint_values_in_kwargs(**kwargs): - endpoint_uri = kwargs["endpoint_uri"] - endpoint_key = kwargs["endpoint_key"] + endpoint_url = kwargs["endpoint_url"] + endpoint_api_key = kwargs["endpoint_api_key"] model_family = kwargs["model_family"] # clean these up, aka don't send them to MIR - del kwargs["endpoint_uri"] - del kwargs["endpoint_key"] + del kwargs["endpoint_url"] + del kwargs["endpoint_api_key"] del kwargs["model_family"] - return (endpoint_uri, endpoint_key, model_family) + return (endpoint_url, endpoint_api_key, model_family) (endpoint_connection_type, endpoint_connection_name) = parse_endpoint_connection_type(endpoint) print(f"endpoint_connection_type: {endpoint_connection_type} name: {endpoint_connection_name}", file=sys.stdout) @@ -1035,25 +1035,26 @@ def get_endpoint_details(self, message = f"""Error encountered while attempting to Authorize access to {endpoint}. 
Exception: {e}""" print(message, file=sys.stderr) - raise OpenSourceLLMUserError(message=message) + raise OpenModelLLMUserError(message=message) if con_type == "serverlessendpoint": - (endpoint_url, endpoint_key, model_family) = SERVERLESS_ENDPOINT_CONTAINER.get_serverless_endpoint_key( + (endpoint_url, endpoint_api_key, model_family) = SERVERLESS_ENDPOINT_CONTAINER.get_serverless_endpoint_key( token, subscription_id, resource_group_name, workspace_name, endpoint_connection_name) elif con_type == "onlineendpoint": - (endpoint_url, endpoint_key, model_family) = self.get_deployment_from_endpoint(credential, - subscription_id, - resource_group_name, - workspace_name, - endpoint_connection_name, - deployment_name) + (endpoint_url, endpoint_api_key, model_family) = self.get_deployment_from_endpoint( + credential, + subscription_id, + resource_group_name, + workspace_name, + endpoint_connection_name, + deployment_name) elif con_type == "connection": (endpoint_url, - endpoint_key, + endpoint_api_key, model_family) = CUSTOM_CONNECTION_CONTAINER.get_endpoint_from_azure_custom_connection( credential, subscription_id, @@ -1062,22 +1063,22 @@ def get_endpoint_details(self, endpoint_connection_name) elif con_type == "localconnection": (endpoint_url, - endpoint_key, + endpoint_api_key, model_family) = CUSTOM_CONNECTION_CONTAINER.get_endpoint_from_local_custom_connection( endpoint_connection_name) else: - raise OpenSourceLLMUserError(message=f"Invalid endpoint connection type: {endpoint_connection_type}") - return (self.sanitize_endpoint_url(endpoint_url, api_type), endpoint_key, model_family) + raise OpenModelLLMUserError(message=f"Invalid endpoint connection type: {endpoint_connection_type}") + return (self.sanitize_endpoint_url(endpoint_url, api_type), endpoint_api_key, model_family) def endpoint_values_in_kwargs(self, **kwargs): # This is mostly for testing, suggest not using this since security\privacy concerns for the endpoint key - if 'endpoint_uri' not in kwargs and 'endpoint_key' not in kwargs and 'model_family' not in kwargs: + if 'endpoint_url' not in kwargs and 'endpoint_api_key' not in kwargs and 'model_family' not in kwargs: return False - if 'endpoint_uri' not in kwargs or 'endpoint_key' not in kwargs or 'model_family' not in kwargs: + if 'endpoint_url' not in kwargs or 'endpoint_api_key' not in kwargs or 'model_family' not in kwargs: message = """Endpoint connection via kwargs not fully set. 
-If using kwargs, the following values must be set: endpoint_uri, endpoint_key, and model_family""" - raise OpenSourceLLMKeyValidationError(message=message) +If using kwargs, the following values must be set: endpoint_url, endpoint_api_key, and model_family""" + raise OpenModelLLMKeyValidationError(message=message) return True @@ -1101,10 +1102,10 @@ def call( if not deployment_name or deployment_name == DEPLOYMENT_DEFAULT: deployment_name = None - print(f"Executing Open Source LLM Tool for endpoint: '{endpoint_name}', deployment: '{deployment_name}'", + print(f"Executing Open Model LLM Tool for endpoint: '{endpoint_name}', deployment: '{deployment_name}'", file=sys.stdout) - (endpoint_uri, endpoint_key, model_family) = self.get_endpoint_details( + (endpoint_url, endpoint_api_key, model_family) = self.get_endpoint_details( subscription_id=os.getenv("AZUREML_ARM_SUBSCRIPTION", None), resource_group_name=os.getenv("AZUREML_ARM_RESOURCEGROUP", None), workspace_name=os.getenv("AZUREML_ARM_WORKSPACE_NAME", None), @@ -1123,12 +1124,12 @@ def call( model_family=model_family, api=api, chat_history=prompt, - endpoint_url=endpoint_uri + endpoint_url=endpoint_url ) llm = AzureMLOnlineEndpoint( - endpoint_url=endpoint_uri, - endpoint_api_key=endpoint_key, + endpoint_url=endpoint_url, + endpoint_api_key=endpoint_api_key, model_family=model_family, content_formatter=content_formatter, deployment_name=deployment_name, diff --git a/src/promptflow-tools/promptflow/tools/yamls/open_source_llm.yaml b/src/promptflow-tools/promptflow/tools/yamls/open_model_llm.yaml similarity index 80% rename from src/promptflow-tools/promptflow/tools/yamls/open_source_llm.yaml rename to src/promptflow-tools/promptflow/tools/yamls/open_model_llm.yaml index 61a942e2eff..05ae6c71165 100644 --- a/src/promptflow-tools/promptflow/tools/yamls/open_source_llm.yaml +++ b/src/promptflow-tools/promptflow/tools/yamls/open_model_llm.yaml @@ -1,17 +1,17 @@ -promptflow.tools.open_source_llm.OpenSourceLLM.call: - name: Open Source LLM - description: Use an Open Source model from the Azure Model catalog, deployed to an AzureML Online Endpoint for LLM Chat or Completion API calls. +promptflow.tools.open_model_llm.OpenModelLLM.call: + name: Open Model LLM + description: Use an open model from the Azure Model catalog, deployed to an AzureML Online Endpoint for LLM Chat or Completion API calls. 
icon: data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAACgElEQVR4nGWSz2vcVRTFP/e9NzOZ1KDGohASslLEH6VLV0ak4l/QpeDCrfQPcNGliODKnVm4EBdBsIjQIlhciKW0ycKFVCSNbYnjdDLtmPnmO/nO9917XcxMkjYX3uLx7nnn3HOuMK2Nix4fP78ZdrYXVkLVWjf3l3B1B+HpcjzGFtmqa6cePz7/x0dnn1n5qhj3iBJPYREIURAJuCtpY8PjReDbrf9WG7H1fuefwQU9qKztTcMJT+PNnEFvjGVDBDlSsH6p/9MLzy6+NxwVqI8RAg4IPmWedMckdLYP6O6UpIaQfvyyXG012+e79/ZfHukoS1ISMT2hGTB1RkUmNgQ5QZ0w+a2VWDq73MbdEWmfnnv6UWe7oNzPaLapl5CwuLTXK9WUGBuCjqekzhP+z52ZXOrKMD3OJg0Hh778aiOuvpnYvp05d6GJO4iAO4QAe/eV36/X5LFRV4Zmn+AdkqlL8Vjp3oVioOz+WTPzzYEgsN+fgPLYyJVheSbPPVl2ikeGZRjtG52/8rHuaV9VOlpP2OtKyVndcRVCSqOhsvxa4vW359i6OuKdD+aP8Q4SYPdOzS/flGjt1JUSaMqZ5nwa1Y8qWb/Ud/eZZkHisYezEM0m+fcelDr8F1SqW2LNK6r1jXQwyLzy1hxvrLXZulry7ocL+FS6G4QIu3fG/Px1gdYeW7LIgXU2P/115TOA5G7e3Rmj2aS/m7l5pThiZzrCcE/d1XHzbln373nw7y6veeoUm5KCNKT/IPPwbiY1hYd/l5MIT65BMFt87sU4v9D7/JMflr44uV6hGh1+L4RCkg6z5iK2tAhNLeLsNGwYA4fDYnC/drvuuFxe86NV/x+Ut27g0FvykgAAAABJRU5ErkJggg== type: custom_llm - module: promptflow.tools.open_source_llm - class_name: OpenSourceLLM + module: promptflow.tools.open_model_llm + class_name: OpenModelLLM function: call inputs: endpoint_name: type: - string dynamic_list: - func_path: promptflow.tools.open_source_llm.list_endpoint_names + func_path: promptflow.tools.open_model_llm.list_endpoint_names allow_manual_entry: true # Allow the user to clear this field is_multi_select: false deployment_name: @@ -19,7 +19,7 @@ promptflow.tools.open_source_llm.OpenSourceLLM.call: type: - string dynamic_list: - func_path: promptflow.tools.open_source_llm.list_deployment_names + func_path: promptflow.tools.open_model_llm.list_deployment_names func_kwargs: - name: endpoint type: diff --git a/src/promptflow-tools/tests/conftest.py b/src/promptflow-tools/tests/conftest.py index cec414bad16..725e0c784c3 100644 --- a/src/promptflow-tools/tests/conftest.py +++ b/src/promptflow-tools/tests/conftest.py @@ -48,7 +48,7 @@ def serp_connection(): return ConnectionManager().get("serp_connection") -def verify_oss_llm_custom_connection(connection: CustomConnection) -> bool: +def verify_om_llm_custom_connection(connection: CustomConnection) -> bool: '''Verify that there is a MIR endpoint up and available for the Custom Connection. We explicitly do not pass the endpoint key to avoid the delay in generating a response. 
''' @@ -66,7 +66,7 @@ def llama_chat_custom_connection(): @pytest.fixture -def open_source_llm_ws_service_connection() -> bool: +def open_model_llm_ws_service_connection() -> bool: try: creds_custom_connection: CustomConnection = ConnectionManager().get("open_source_llm_ws_service_connection") subs = json.loads(creds_custom_connection.secrets['service_credential']) @@ -92,9 +92,9 @@ def skip_if_no_api_key(request, mocker): elif isinstance(connection, CustomConnection): if "endpoint_api_key" not in connection.secrets or "-api-key" in connection.secrets["endpoint_api_key"]: pytest.skip('skipped because no key') - # Verify Custom Connections, but only those used by the Open_Source_LLM Tool + # Verify Custom Connections, but only those used by the Open_Model_LLM Tool if "endpoint_url" in connection.configs and "-endpoint-url" not in connection.configs["endpoint_url"]: - if not verify_oss_llm_custom_connection(connection): + if not verify_om_llm_custom_connection(connection): pytest.skip('skipped because the connection is not valid') diff --git a/src/promptflow-tools/tests/test_open_source_llm.py b/src/promptflow-tools/tests/test_open_model_llm.py similarity index 86% rename from src/promptflow-tools/tests/test_open_source_llm.py rename to src/promptflow-tools/tests/test_open_model_llm.py index 19a3fec22b7..d14727b6ebc 100644 --- a/src/promptflow-tools/tests/test_open_source_llm.py +++ b/src/promptflow-tools/tests/test_open_model_llm.py @@ -6,11 +6,11 @@ from typing import List, Dict from promptflow.tools.exception import ( - OpenSourceLLMUserError, - OpenSourceLLMKeyValidationError + OpenModelLLMUserError, + OpenModelLLMKeyValidationError ) -from promptflow.tools.open_source_llm import ( - OpenSourceLLM, +from promptflow.tools.open_model_llm import ( + OpenModelLLM, API, ContentFormatterBase, LlamaContentFormatter, @@ -28,7 +28,7 @@ def validate_response(response): def verify_prompt_role_delimiters(message: str, codes: List[str]): - assert codes == "UserError/OpenSourceLLMUserError".split("/") + assert codes == "UserError/OpenModelLLMUserError".split("/") message_pattern = re.compile( r"The Chat API requires a specific format for prompt definition, and the prompt should include separate " @@ -42,10 +42,10 @@ def verify_prompt_role_delimiters(message: str, codes: List[str]): @pytest.fixture -def verify_service_endpoints(open_source_llm_ws_service_connection) -> Dict[str, List[str]]: - if not open_source_llm_ws_service_connection: +def verify_service_endpoints(open_model_llm_ws_service_connection) -> Dict[str, List[str]]: + if not open_model_llm_ws_service_connection: pytest.skip("Service Credential not available") - print("open_source_llm_ws_service_connection completed") + print("open_model_llm_ws_service_connection completed") required_env_vars = ["AZUREML_ARM_SUBSCRIPTION", "AZUREML_ARM_RESOURCEGROUP", "AZUREML_ARM_WORKSPACE_NAME", "AZURE_CLIENT_ID", "AZURE_TENANT_ID", "AZURE_CLIENT_SECRET"] for rev in required_env_vars: @@ -103,8 +103,8 @@ def completion_endpoints_provider(endpoints_provider: Dict[str, List[str]]) -> D @pytest.mark.usefixtures("use_secrets_config_file") -class TestOpenSourceLLM: - stateless_os_llm = OpenSourceLLM() +class TestOpenModelLLM: + stateless_os_llm = OpenModelLLM() gpt2_connection = "connection/gpt2_connection" llama_connection = "connection/llama_chat_connection" llama_serverless_connection = "connection/llama_chat_serverless" @@ -118,14 +118,14 @@ class TestOpenSourceLLM: user: """ + completion_prompt - def test_open_source_llm_completion(self, 
verify_service_endpoints): + def test_open_model_llm_completion(self, verify_service_endpoints): response = self.stateless_os_llm.call( self.completion_prompt, API.COMPLETION, endpoint_name=self.gpt2_connection) validate_response(response) - def test_open_source_llm_completion_with_deploy(self, verify_service_endpoints): + def test_open_model_llm_completion_with_deploy(self, verify_service_endpoints): response = self.stateless_os_llm.call( self.completion_prompt, API.COMPLETION, @@ -133,14 +133,14 @@ def test_open_source_llm_completion_with_deploy(self, verify_service_endpoints): deployment_name="gpt2-10") validate_response(response) - def test_open_source_llm_chat(self, verify_service_endpoints): + def test_open_model_llm_chat(self, verify_service_endpoints): response = self.stateless_os_llm.call( self.chat_prompt, API.CHAT, endpoint_name=self.gpt2_connection) validate_response(response) - def test_open_source_llm_chat_with_deploy(self, verify_service_endpoints): + def test_open_model_llm_chat_with_deploy(self, verify_service_endpoints): response = self.stateless_os_llm.call( self.chat_prompt, API.CHAT, @@ -148,7 +148,7 @@ def test_open_source_llm_chat_with_deploy(self, verify_service_endpoints): deployment_name="gpt2-10") validate_response(response) - def test_open_source_llm_chat_with_max_length(self, verify_service_endpoints): + def test_open_model_llm_chat_with_max_length(self, verify_service_endpoints): response = self.stateless_os_llm.call( self.chat_prompt, API.CHAT, @@ -158,49 +158,49 @@ def test_open_source_llm_chat_with_max_length(self, verify_service_endpoints): validate_response(response) @pytest.mark.skip_if_no_api_key("gpt2_custom_connection") - def test_open_source_llm_con_url_chat(self, gpt2_custom_connection): + def test_open_model_llm_con_url_chat(self, gpt2_custom_connection): tmp = copy.deepcopy(gpt2_custom_connection) del tmp.configs['endpoint_url'] - with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info: + with pytest.raises(OpenModelLLMKeyValidationError) as exc_info: customConnectionsContainer = CustomConnectionsContainer() customConnectionsContainer.get_endpoint_from_custom_connection(connection=tmp) assert exc_info.value.message == """Required key `endpoint_url` not found in given custom connection. Required keys are: endpoint_url,model_family.""" - assert exc_info.value.error_codes == "UserError/ToolValidationError/OpenSourceLLMKeyValidationError".split("/") + assert exc_info.value.error_codes == "UserError/ToolValidationError/OpenModelLLMKeyValidationError".split("/") @pytest.mark.skip_if_no_api_key("gpt2_custom_connection") - def test_open_source_llm_con_key_chat(self, gpt2_custom_connection): + def test_open_model_llm_con_key_chat(self, gpt2_custom_connection): tmp = copy.deepcopy(gpt2_custom_connection) del tmp.secrets['endpoint_api_key'] - with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info: + with pytest.raises(OpenModelLLMKeyValidationError) as exc_info: customConnectionsContainer = CustomConnectionsContainer() customConnectionsContainer.get_endpoint_from_custom_connection(connection=tmp) assert exc_info.value.message == ( "Required secret key `endpoint_api_key` " + """not found in given custom connection. 
Required keys are: endpoint_api_key.""") - assert exc_info.value.error_codes == "UserError/ToolValidationError/OpenSourceLLMKeyValidationError".split("/") + assert exc_info.value.error_codes == "UserError/ToolValidationError/OpenModelLLMKeyValidationError".split("/") @pytest.mark.skip_if_no_api_key("gpt2_custom_connection") - def test_open_source_llm_con_model_chat(self, gpt2_custom_connection): + def test_open_model_llm_con_model_chat(self, gpt2_custom_connection): tmp = copy.deepcopy(gpt2_custom_connection) del tmp.configs['model_family'] - with pytest.raises(OpenSourceLLMKeyValidationError) as exc_info: + with pytest.raises(OpenModelLLMKeyValidationError) as exc_info: customConnectionsContainer = CustomConnectionsContainer() customConnectionsContainer.get_endpoint_from_custom_connection(connection=tmp) assert exc_info.value.message == """Required key `model_family` not found in given custom connection. Required keys are: endpoint_url,model_family.""" - assert exc_info.value.error_codes == "UserError/ToolValidationError/OpenSourceLLMKeyValidationError".split("/") + assert exc_info.value.error_codes == "UserError/ToolValidationError/OpenModelLLMKeyValidationError".split("/") - def test_open_source_llm_escape_chat(self): + def test_open_model_llm_escape_chat(self): danger = r"The quick \brown fox\tjumped\\over \the \\boy\r\n" out_of_danger = ContentFormatterBase.escape_special_characters(danger) assert out_of_danger == "The quick \\brown fox\\tjumped\\\\over \\the \\\\boy\\r\\n" - def test_open_source_llm_llama_parse_chat_with_chat(self): + def test_open_model_llm_llama_parse_chat_with_chat(self): LlamaContentFormatter.parse_chat(self.chat_prompt) - def test_open_source_llm_llama_parse_multi_turn(self): + def test_open_model_llm_llama_parse_multi_turn(self): multi_turn_chat = """user: You are a AI which helps Customers answer questions. @@ -214,7 +214,7 @@ def test_open_source_llm_llama_parse_multi_turn(self): """ LlamaContentFormatter.parse_chat(multi_turn_chat) - def test_open_source_llm_llama_parse_ignore_whitespace(self): + def test_open_model_llm_llama_parse_ignore_whitespace(self): bad_chat_prompt = f"""system: You are a AI which helps Customers answer questions. 
@@ -222,16 +222,16 @@ def test_open_source_llm_llama_parse_ignore_whitespace(self): user: {self.completion_prompt}""" - with pytest.raises(OpenSourceLLMUserError) as exc_info: + with pytest.raises(OpenModelLLMUserError) as exc_info: LlamaContentFormatter.parse_chat(bad_chat_prompt) verify_prompt_role_delimiters(exc_info.value.message, exc_info.value.error_codes) - def test_open_source_llm_llama_parse_chat_with_comp(self): - with pytest.raises(OpenSourceLLMUserError) as exc_info: + def test_open_model_llm_llama_parse_chat_with_comp(self): + with pytest.raises(OpenModelLLMUserError) as exc_info: LlamaContentFormatter.parse_chat(self.completion_prompt) verify_prompt_role_delimiters(exc_info.value.message, exc_info.value.error_codes) - def test_open_source_llm_chat_endpoint_name(self, chat_endpoints_provider): + def test_open_model_llm_chat_endpoint_name(self, chat_endpoints_provider): for endpoint_name in chat_endpoints_provider: response = self.stateless_os_llm.call( self.chat_prompt, @@ -239,7 +239,7 @@ def test_open_source_llm_chat_endpoint_name(self, chat_endpoints_provider): endpoint_name=f"onlineEndpoint/{endpoint_name}") validate_response(response) - def test_open_source_llm_chat_endpoint_name_with_deployment(self, chat_endpoints_provider): + def test_open_model_llm_chat_endpoint_name_with_deployment(self, chat_endpoints_provider): for endpoint_name in chat_endpoints_provider: for deployment_name in chat_endpoints_provider[endpoint_name]: response = self.stateless_os_llm.call( @@ -249,7 +249,7 @@ def test_open_source_llm_chat_endpoint_name_with_deployment(self, chat_endpoints deployment_name=deployment_name) validate_response(response) - def test_open_source_llm_completion_endpoint_name(self, completion_endpoints_provider): + def test_open_model_llm_completion_endpoint_name(self, completion_endpoints_provider): for endpoint_name in completion_endpoints_provider: response = self.stateless_os_llm.call( self.completion_prompt, @@ -257,7 +257,7 @@ def test_open_source_llm_completion_endpoint_name(self, completion_endpoints_pro endpoint_name=f"onlineEndpoint/{endpoint_name}") validate_response(response) - def test_open_source_llm_completion_endpoint_name_with_deployment(self, completion_endpoints_provider): + def test_open_model_llm_completion_endpoint_name_with_deployment(self, completion_endpoints_provider): for endpoint_name in completion_endpoints_provider: for deployment_name in completion_endpoints_provider[endpoint_name]: response = self.stateless_os_llm.call( @@ -267,18 +267,18 @@ def test_open_source_llm_completion_endpoint_name_with_deployment(self, completi deployment_name=deployment_name) validate_response(response) - def test_open_source_llm_llama_chat(self, verify_service_endpoints): + def test_open_model_llm_llama_chat(self, verify_service_endpoints): response = self.stateless_os_llm.call(self.chat_prompt, API.CHAT, endpoint_name=self.llama_connection) validate_response(response) - def test_open_source_llm_llama_serverless(self, verify_service_endpoints): + def test_open_model_llm_llama_serverless(self, verify_service_endpoints): response = self.stateless_os_llm.call( self.chat_prompt, API.CHAT, endpoint_name=self.llama_serverless_connection) validate_response(response) - def test_open_source_llm_llama_chat_history(self, verify_service_endpoints): + def test_open_model_llm_llama_chat_history(self, verify_service_endpoints): chat_history_prompt = """system: * Given the following conversation history and the users next question, answer the next question. 
* If the conversation is irrelevant or empty, acknowledge and ask for more input. @@ -330,7 +330,7 @@ def test_open_source_llm_llama_chat_history(self, verify_service_endpoints): chat_input="Sorry I didn't follow, could you say that again?") validate_response(response) - def test_open_source_llm_dynamic_list_ignore_deployment(self, verify_service_endpoints): + def test_open_model_llm_dynamic_list_ignore_deployment(self, verify_service_endpoints): deployments = list_deployment_names( subscription_id=os.getenv("AZUREML_ARM_SUBSCRIPTION"), resource_group_name=os.getenv("AZUREML_ARM_RESOURCEGROUP"), @@ -355,7 +355,7 @@ def test_open_source_llm_dynamic_list_ignore_deployment(self, verify_service_end assert len(deployments) == 1 assert deployments[0]['value'] == 'default' - def test_open_source_llm_dynamic_list_serverless_test(self, verify_service_endpoints): + def test_open_model_llm_dynamic_list_serverless_test(self, verify_service_endpoints): subscription_id = os.getenv("AZUREML_ARM_SUBSCRIPTION") resource_group_name = os.getenv("AZUREML_ARM_RESOURCEGROUP") workspace_name = os.getenv("AZUREML_ARM_WORKSPACE_NAME") @@ -395,7 +395,7 @@ def test_open_source_llm_dynamic_list_serverless_test(self, verify_service_endpo assert model_family == "LLaMa" assert endpoint_key == eps_keys['primaryKey'] - def test_open_source_llm_dynamic_list_custom_connections_test(self, verify_service_endpoints): + def test_open_model_llm_dynamic_list_custom_connections_test(self, verify_service_endpoints): custom_container = CustomConnectionsContainer() credential = DefaultAzureCredential(exclude_interactive_browser_credential=False) @@ -406,7 +406,7 @@ def test_open_source_llm_dynamic_list_custom_connections_test(self, verify_servi workspace_name=os.getenv("AZUREML_ARM_WORKSPACE_NAME")) assert len(connections) > 1 - def test_open_source_llm_dynamic_list_happy_path(self, verify_service_endpoints): + def test_open_model_llm_dynamic_list_happy_path(self, verify_service_endpoints): endpoints = list_endpoint_names( subscription_id=os.getenv("AZUREML_ARM_SUBSCRIPTION"), resource_group_name=os.getenv("AZUREML_ARM_RESOURCEGROUP"), @@ -466,7 +466,7 @@ def test_open_source_llm_dynamic_list_happy_path(self, verify_service_endpoints) model_kwargs={}) validate_response(response) - def test_open_source_llm_get_model_llama(self): + def test_open_model_llm_get_model_llama(self): model_assets = [ "azureml://registries/azureml-meta/models/Llama-2-7b-chat/versions/14", "azureml://registries/azureml-meta/models/Llama-2-7b/versions/12", @@ -479,7 +479,7 @@ def test_open_source_llm_get_model_llama(self): for asset_name in model_assets: assert ModelFamily.LLAMA == get_model_type(asset_name) - def test_open_source_llm_get_model_gpt2(self): + def test_open_model_llm_get_model_gpt2(self): model_assets = [ "azureml://registries/azureml-staging/models/gpt2/versions/9", "azureml://registries/azureml/models/gpt2/versions/9", @@ -490,7 +490,7 @@ def test_open_source_llm_get_model_gpt2(self): for asset_name in model_assets: assert ModelFamily.GPT2 == get_model_type(asset_name) - def test_open_source_llm_get_model_dolly(self): + def test_open_model_llm_get_model_dolly(self): model_assets = [ "azureml://registries/azureml/models/databricks-dolly-v2-12b/versions/11" ] @@ -498,7 +498,7 @@ def test_open_source_llm_get_model_dolly(self): for asset_name in model_assets: assert ModelFamily.DOLLY == get_model_type(asset_name) - def test_open_source_llm_get_model_falcon(self): + def test_open_model_llm_get_model_falcon(self): model_assets = [ 
"azureml://registries/azureml/models/tiiuae-falcon-40b/versions/2", "azureml://registries/azureml/models/tiiuae-falcon-40b/versions/2" @@ -507,7 +507,7 @@ def test_open_source_llm_get_model_falcon(self): for asset_name in model_assets: assert ModelFamily.FALCON == get_model_type(asset_name) - def test_open_source_llm_get_model_failure_cases(self): + def test_open_model_llm_get_model_failure_cases(self): bad_model_assets = [ "azureml://registries/azureml-meta/models/CodeLlama-7b-Instruct-hf/versions/3", "azureml://registries/azureml-staging/models/gpt-2/versions/9",