From 8ffe4cb999301f3ac3d6ce44109a95fc52013c50 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 11 Jul 2023 07:44:13 +0000 Subject: [PATCH 01/15] Add remote inference integration tests for OpenAI and Cohere Signed-off-by: Ryan Bogan --- .github/workflows/CI-workflow.yml | 3 + .../ml/rest/RestMLRemoteInferenceIT.java | 746 ++++++++++++++++++ 2 files changed, 749 insertions(+) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index d9fc74c080..fd71212c89 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -1,4 +1,7 @@ name: Build and Test ml-commons +env: + OPENAI_KEY: ${{ secrets.OPENAI_KEY }} + COHERE_KEY: ${{ secrets.COHERE_KEY }} # This workflow is triggered on pull requests and push to any branches on: pull_request: diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java new file mode 100644 index 0000000000..20c6cc3482 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -0,0 +1,746 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.opensearch.client.Response; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { + + private final String completionModelConnectorEntity = "{\n" + + "\"name\": \"OpenAI Connector\",\n" + + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + "\"version\": 1,\n" + + "\"protocol\": \"http\",\n" + + "\"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"text-davinci-003\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ]\n" + + "}"; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + private void setEncryptionMasterKey() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.encryption.master_key\":\"0000000000000011\"}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testCreateConnector() throws IOException { + System.out.println(System.getenv()); + disableClusterConnectorAccessControl(); + setEncryptionMasterKey(); + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + assertNotNull((String) responseMap.get("connector_id")); + } + + public void testGetConnector() throws IOException { + disableClusterConnectorAccessControl(); + setEncryptionMasterKey(); + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = TestHelper + .makeRequest( + client(), + "GET", + "/_plugins/_ml/connectors/" + connectorId, + null, + "", + null + ); + responseMap = parseResponseToMap(response); + assertEquals("OpenAI Connector", (String) responseMap.get("name")); + assertEquals("1", (String) responseMap.get("version")); + assertEquals("The connector to public OpenAI model service for GPT 3.5", (String) responseMap.get("description")); + assertEquals("http", (String) responseMap.get("protocol")); + } + + @Ignore + public void testDeleteConnector() throws IOException { + disableClusterConnectorAccessControl(); + setEncryptionMasterKey(); + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = TestHelper + .makeRequest( + client(), + "DELETE", + "/_plugins/_ml/connectors/" + connectorId, + null, + "", + null + ); + responseMap = parseResponseToMap(response); + assertEquals("deleted", (String) responseMap.get("result")); + } + + public void testSearchConnectors() throws IOException { + disableClusterConnectorAccessControl(); + setEncryptionMasterKey(); + createConnector(completionModelConnectorEntity); + String searchEntity = "{\n" + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " },\n" + + " \"size\": 1000\n" + + "}"; + Response response = TestHelper + .makeRequest( + client(), + "GET", + "/_plugins/_ml/connectors/_search", + null, + TestHelper.toHttpEntity(searchEntity), + null + ); + Map responseMap = parseResponseToMap(response); + assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); + + } + + public void testRegisterRemoteModel() throws IOException, InterruptedException { + disableClusterConnectorAccessControl(); + setEncryptionMasterKey(); + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + assertNotNull(responseMap.get("model_id")); + } + + public void testDeployRemoteModel() throws IOException, InterruptedException { + disableClusterConnectorAccessControl(); + setEncryptionMasterKey(); + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + assertEquals("CREATED", (String) responseMap.get("status")); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + } + + public void testPredictRemoteModel() throws IOException, InterruptedException { + disableClusterConnectorAccessControl(); + setEncryptionMasterKey(); + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"prompt\": \"Say this is a test\"\n" + + " }\n" + + "}"; + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("choices"); + responseMap = (Map) responseList.get(0); + assertEquals("\n\nThis is indeed a test", (String) responseMap.get("text")); + } + + public void testUndeployRemoteModel() throws IOException, InterruptedException { + disableClusterConnectorAccessControl(); + setEncryptionMasterKey(); + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = undeployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + assertTrue(responseMap.toString().contains("undeployed")); + } + + @Ignore + public void testOpenAIChatCompletionModel() throws IOException, InterruptedException { + String entity = "{\n" + + " \"name\": \"OpenAI chat model Connector\",\n" + + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"gpt-3.5-turbo\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"\n" + + " }\n" + + " ]\n" + + "}"; + Response response = createConnector(entity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("openAI-GPT-3.5 chat model", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}]\n" + + " }\n" + + "}"; + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + } + + public void testOpenAIEditsModel() throws IOException, InterruptedException { + String entity = "{\n" + + " \"name\": \"OpenAI Edit model Connector\",\n" + + " \"description\": \"The connector to public OpenAI edit model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"text-davinci-edit-001\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/edits\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + Response response = createConnector(entity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("openAI-GPT-3.5 edit model", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": \"What day of the wek is it?\",\n" + + " \"instruction\": \"Fix the spelling mistakes\"\n" + + " }\n" + + "}"; + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("choices"); + responseMap = (Map) responseList.get(0); + assertTrue(((String) responseMap.get("text")).contains("What day of the week is it?")); + } + + public void testOpenAIModerationsModel() throws IOException, InterruptedException { + String entity = "{\n" + + " \"name\": \"OpenAI moderations model Connector\",\n" + + " \"description\": \"The connector to public OpenAI moderations model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"moderations\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/moderations\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"input\\\": \\\"${parameters.input}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + Response response = createConnector(entity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("openAI-GPT-3.5 moderations model", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": \"I want to kill them.\"\n" + + " }\n" + + "}"; + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("results"); + responseMap = (Map) responseList.get(0); + assertTrue((Boolean) responseMap.get("flagged")); + responseMap = (Map) responseMap.get("categories"); + assertTrue((Boolean) responseMap.get("violence")); + } + + @Ignore + public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException { + String entity = "{\n" + + " \"name\": \"OpenAI text embedding model Connector\",\n" + + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"text-embedding-ada-002\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n" + + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n" + + " \"post_process_function\": \"openai_embedding\"\n" + + " }\n" + + " ]\n" + + "}"; + Response response = createConnector(entity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("openAI text embedding model", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": \"The food was delicious\"\n" + + " }\n" + + "}"; + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("data"); + responseMap = (Map) responseList.get(0); + assertFalse(((List) responseMap.get("embedding")).isEmpty()); + } + + public void testCohereGenerateTextModel() throws IOException, InterruptedException { + String entity = "{\n" + + " \"name\": \"Cohere generate text model Connector\",\n" + + " \"description\": \"The connector to public Cohere generate text model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.cohere.ai\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": \"20\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"cohere_key\": \"" + System.getenv("COHERE_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + Response response = createConnector(entity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("cohere generate text model", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"prompt\": \"Once upon a time in a magical land called\",\n" + + " \"max_tokens\": 40\n" + + " }\n" + + "}"; + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("generations"); + responseMap = (Map) responseList.get(0); + assertFalse(((String) responseMap.get("text")).isEmpty()); + } + + @Ignore + public void testCohereClassifyModel() throws IOException, InterruptedException { + String entity = "{\n" + + " \"name\": \"Cohere classify model Connector\",\n" + + " \"description\": \"The connector to public Cohere classify model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.cohere.ai\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": \"20\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"cohere_key\": \"" + System.getenv("COHERE_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + Response response = createConnector(entity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("cohere classify model", connectorId); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"inputs\": [\n" + + " \"Confirm your email address\",\n" + + " \"hey i need u to send some $\"\n" + + " ],\n" + + " \"examples\": [\n" + + " {\n" + + " \"text\": \"Dermatologists don't like her!\",\n" + + " \"label\": \"Spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"Hello, open to this?\",\n" + + " \"label\": \"Spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"I need help please wire me $1000 right now\",\n" + + " \"label\": \"Spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"Nice to know you ;)\",\n" + + " \"label\": \"Spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"Please help me?\",\n" + + " \"label\": \"Spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"Your parcel will be delivered today\",\n" + + " \"label\": \"Not spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"Review changes to our Terms and Conditions\",\n" + + " \"label\": \"Not spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"Weekly sync notes\",\n" + + " \"label\": \"Not spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"Re: Follow up from todays meeting\",\n" + + " \"label\": \"Not spam\"\n" + + " },\n" + + " {\n" + + " \"text\": \"Pre-read for tomorrow\",\n" + + " \"label\": \"Not spam\"\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + } + + private Response createConnector(String input) throws IOException { + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/connectors/_create", + null, + TestHelper.toHttpEntity(input), + null + ); + } + + private Response registerRemoteModel(String name, String connectorId) throws IOException { + String registerModelGroupEntity = "{\n" + + " \"name\": \"remote_model_group\",\n" + + " \"description\": \"This is an example description\"\n" + + "}"; + Response response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/model_groups/_register", + null, + TestHelper.toHttpEntity(registerModelGroupEntity), + null + ); + Map responseMap = parseResponseToMap(response); + assertEquals((String) responseMap.get("status"), "CREATED"); + String modelGroupId = (String) responseMap.get("model_group_id"); + + String registerModelEntity = "{\n" + + " \"name\": \"" + name + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + modelGroupId + "\",\n" + + " \"version\": \"1.0.0\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + connectorId + "\"\n" + + "}"; + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/_register", + null, + TestHelper.toHttpEntity(registerModelEntity), + null + ); + } + + private Response deployRemoteModel(String modelId) throws IOException { + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/" + modelId + "/_deploy", + null, + "", + null + ); + } + + private Response predictRemoteModel(String modelId, String input) throws IOException { + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/" + modelId + "/_predict", + null, + input, + null + ); + } + + private Response undeployRemoteModel(String modelId) throws IOException { + String undeployEntity = "{\n" + + " \"SYqCMdsFTumUwoHZcsgiUg\": {\n" + + " \"stats\": {\n" + + " \"" + modelId + "\": \"undeployed\"\n" + + " }\n" + + " }\n" + + "}"; + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/" + modelId + "/_undeploy", + null, + undeployEntity, + null + ); + } + + private Map parseResponseToMap(Response response) throws IOException { + HttpEntity entity = response.getEntity(); + assertNotNull(response); + String entityString = TestHelper.httpEntityToString(entity); + return gson.fromJson(entityString, Map.class); + } + + private void disableClusterConnectorAccessControl() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.connector_access_control_enabled\":false}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + private Response getTask(String taskId) throws IOException { + return TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null); + } + +} From fe21a5f1cf7a60290bd49107b4e84b213c72b8ea Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 11 Jul 2023 08:07:33 +0000 Subject: [PATCH 02/15] Change edits model test Signed-off-by: Ryan Bogan --- .../ml/rest/RestMLRemoteInferenceIT.java | 495 ++++++++---------- 1 file changed, 222 insertions(+), 273 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 20c6cc3482..33665d5fe8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -12,14 +12,12 @@ import org.apache.http.HttpEntity; import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; -import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; import org.opensearch.ml.common.MLTaskState; -import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.utils.TestHelper; import com.google.common.collect.ImmutableList; @@ -27,34 +25,36 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { private final String completionModelConnectorEntity = "{\n" - + "\"name\": \"OpenAI Connector\",\n" - + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" - + "\"version\": 1,\n" - + "\"protocol\": \"http\",\n" - + "\"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": 7,\n" - + " \"temperature\": 0,\n" - + " \"model\": \"text-davinci-003\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" - + " \"headers\": {\n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" - + " }\n" - + " ]\n" - + "}"; - + + "\"name\": \"OpenAI Connector\",\n" + + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + "\"version\": 1,\n" + + "\"protocol\": \"http\",\n" + + "\"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"text-davinci-003\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ]\n" + + "}"; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -71,7 +71,7 @@ private void setEncryptionMasterKey() throws IOException { ); assertEquals(200, response.getStatusLine().getStatusCode()); } - + public void testCreateConnector() throws IOException { System.out.println(System.getenv()); disableClusterConnectorAccessControl(); @@ -87,15 +87,7 @@ public void testGetConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = TestHelper - .makeRequest( - client(), - "GET", - "/_plugins/_ml/connectors/" + connectorId, - null, - "", - null - ); + response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/connectors/" + connectorId, null, "", null); responseMap = parseResponseToMap(response); assertEquals("OpenAI Connector", (String) responseMap.get("name")); assertEquals("1", (String) responseMap.get("version")); @@ -110,15 +102,7 @@ public void testDeleteConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = TestHelper - .makeRequest( - client(), - "DELETE", - "/_plugins/_ml/connectors/" + connectorId, - null, - "", - null - ); + response = TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/connectors/" + connectorId, null, "", null); responseMap = parseResponseToMap(response); assertEquals("deleted", (String) responseMap.get("result")); } @@ -127,24 +111,12 @@ public void testSearchConnectors() throws IOException { disableClusterConnectorAccessControl(); setEncryptionMasterKey(); createConnector(completionModelConnectorEntity); - String searchEntity = "{\n" - + " \"query\": {\n" - + " \"match_all\": {}\n" - + " },\n" - + " \"size\": 1000\n" - + "}"; + String searchEntity = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " },\n" + " \"size\": 1000\n" + "}"; Response response = TestHelper - .makeRequest( - client(), - "GET", - "/_plugins/_ml/connectors/_search", - null, - TestHelper.toHttpEntity(searchEntity), - null - ); + .makeRequest(client(), "GET", "/_plugins/_ml/connectors/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); - + } public void testRegisterRemoteModel() throws IOException, InterruptedException { @@ -199,11 +171,7 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"prompt\": \"Say this is a test\"\n" - + " }\n" - + "}"; + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -254,17 +222,19 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep + " \"model\": \"gpt-3.5-turbo\"\n" + " },\n" + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + " },\n" + " \"actions\": [\n" + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"\n" + " }\n" + " ]\n" + "}"; @@ -293,31 +263,33 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep public void testOpenAIEditsModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI Edit model Connector\",\n" - + " \"description\": \"The connector to public OpenAI edit model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"text-davinci-edit-001\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/edits\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI Edit model Connector\",\n" + + " \"description\": \"The connector to public OpenAI edit model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"text-davinci-edit-001\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/edits\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -338,45 +310,55 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { + " \"instruction\": \"Fix the spelling mistakes\"\n" + " }\n" + "}"; - response = predictRemoteModel(modelId, predictInput); - responseMap = parseResponseToMap(response); - List responseList = (List) responseMap.get("inference_results"); - responseMap = (Map) responseList.get(0); - responseList = (List) responseMap.get("output"); - responseMap = (Map) responseList.get(0); - responseMap = (Map) responseMap.get("dataAsMap"); - responseList = (List) responseMap.get("choices"); - responseMap = (Map) responseList.get(0); - assertTrue(((String) responseMap.get("text")).contains("What day of the week is it?")); + // Currently, the OpenAI Edits model does not work 100% of the time. This loop tries 3 times to see if the problem is with the model + // or with OpenSearch. + boolean editsModelSuccess = false; + for (int i = 0; i < 3; i++) { + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("choices"); + responseMap = (Map) responseList.get(0); + if (((String) responseMap.get("text")).contains("What day of the week is it?")) { + editsModelSuccess = true; + } + } + assertTrue(editsModelSuccess); } public void testOpenAIModerationsModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI moderations model Connector\",\n" - + " \"description\": \"The connector to public OpenAI moderations model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"moderations\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/moderations\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"input\\\": \\\"${parameters.input}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI moderations model Connector\",\n" + + " \"description\": \"The connector to public OpenAI moderations model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"moderations\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/moderations\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"input\\\": \\\"${parameters.input}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -391,11 +373,7 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": \"I want to kill them.\"\n" - + " }\n" - + "}"; + String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"I want to kill them.\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -413,33 +391,35 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio @Ignore public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI text embedding model Connector\",\n" - + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"text-embedding-ada-002\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n" - + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n" - + " \"post_process_function\": \"openai_embedding\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI text embedding model Connector\",\n" + + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"text-embedding-ada-002\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n" + + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n" + + " \"post_process_function\": \"openai_embedding\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -454,11 +434,7 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": \"The food was delicious\"\n" - + " }\n" - + "}"; + String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"The food was delicious\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -473,31 +449,33 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept public void testCohereGenerateTextModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"Cohere generate text model Connector\",\n" - + " \"description\": \"The connector to public Cohere generate text model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" + System.getenv("COHERE_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"Cohere generate text model Connector\",\n" + + " \"description\": \"The connector to public Cohere generate text model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.cohere.ai\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": \"20\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"cohere_key\": \"" + + System.getenv("COHERE_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -533,31 +511,33 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti @Ignore public void testCohereClassifyModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"Cohere classify model Connector\",\n" - + " \"description\": \"The connector to public Cohere classify model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" + System.getenv("COHERE_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"Cohere classify model Connector\",\n" + + " \"description\": \"The connector to public Cohere classify model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.cohere.ai\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": \"20\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"cohere_key\": \"" + + System.getenv("COHERE_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -628,15 +608,7 @@ public void testCohereClassifyModel() throws IOException, InterruptedException { } private Response createConnector(String input) throws IOException { - return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/connectors/_create", - null, - TestHelper.toHttpEntity(input), - null - ); + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null); } private Response registerRemoteModel(String name, String connectorId) throws IOException { @@ -652,71 +624,48 @@ private Response registerRemoteModel(String name, String connectorId) throws IOE null, TestHelper.toHttpEntity(registerModelGroupEntity), null - ); + ); Map responseMap = parseResponseToMap(response); assertEquals((String) responseMap.get("status"), "CREATED"); String modelGroupId = (String) responseMap.get("model_group_id"); String registerModelEntity = "{\n" - + " \"name\": \"" + name + "\",\n" + + " \"name\": \"" + + name + + "\",\n" + " \"function_name\": \"remote\",\n" - + " \"model_group_id\": \"" + modelGroupId + "\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + " \"version\": \"1.0.0\",\n" + " \"description\": \"test model\",\n" - + " \"connector_id\": \"" + connectorId + "\"\n" + + " \"connector_id\": \"" + + connectorId + + "\"\n" + "}"; return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/models/_register", - null, - TestHelper.toHttpEntity(registerModelEntity), - null - ); + .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } private Response deployRemoteModel(String modelId) throws IOException { - return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/models/" + modelId + "/_deploy", - null, - "", - null - ); + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null); } private Response predictRemoteModel(String modelId, String input) throws IOException { - return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/models/" + modelId + "/_predict", - null, - input, - null - ); + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, input, null); } private Response undeployRemoteModel(String modelId) throws IOException { String undeployEntity = "{\n" + " \"SYqCMdsFTumUwoHZcsgiUg\": {\n" + " \"stats\": {\n" - + " \"" + modelId + "\": \"undeployed\"\n" + + " \"" + + modelId + + "\": \"undeployed\"\n" + " }\n" + " }\n" + "}"; - return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/models/" + modelId + "/_undeploy", - null, - undeployEntity, - null - ); + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, undeployEntity, null); } private Map parseResponseToMap(Response response) throws IOException { From d10e4d34a2f1b1d329f3ada632c06a9999c4dd51 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 11 Jul 2023 08:18:34 +0000 Subject: [PATCH 03/15] Change loop number in edits test Signed-off-by: Ryan Bogan --- .../java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 33665d5fe8..d409c28f05 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -313,7 +313,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { // Currently, the OpenAI Edits model does not work 100% of the time. This loop tries 3 times to see if the problem is with the model // or with OpenSearch. boolean editsModelSuccess = false; - for (int i = 0; i < 3; i++) { + for (int i = 0; i < 5; i++) { response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); From 2c4fac67b14910cf9db48ff356664022b5c7e789 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 11 Jul 2023 08:28:08 +0000 Subject: [PATCH 04/15] Change edits model test again Signed-off-by: Ryan Bogan --- .../ml/rest/RestMLRemoteInferenceIT.java | 495 ++++++++++-------- 1 file changed, 273 insertions(+), 222 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index d409c28f05..a746c8324d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -12,12 +12,14 @@ import org.apache.http.HttpEntity; import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; +import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.utils.TestHelper; import com.google.common.collect.ImmutableList; @@ -25,36 +27,34 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { private final String completionModelConnectorEntity = "{\n" - + "\"name\": \"OpenAI Connector\",\n" - + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" - + "\"version\": 1,\n" - + "\"protocol\": \"http\",\n" - + "\"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": 7,\n" - + " \"temperature\": 0,\n" - + " \"model\": \"text-davinci-003\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" - + " \"headers\": {\n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" - + " }\n" - + " ]\n" - + "}"; - + + "\"name\": \"OpenAI Connector\",\n" + + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + "\"version\": 1,\n" + + "\"protocol\": \"http\",\n" + + "\"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"text-davinci-003\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ]\n" + + "}"; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -71,7 +71,7 @@ private void setEncryptionMasterKey() throws IOException { ); assertEquals(200, response.getStatusLine().getStatusCode()); } - + public void testCreateConnector() throws IOException { System.out.println(System.getenv()); disableClusterConnectorAccessControl(); @@ -87,7 +87,15 @@ public void testGetConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/connectors/" + connectorId, null, "", null); + response = TestHelper + .makeRequest( + client(), + "GET", + "/_plugins/_ml/connectors/" + connectorId, + null, + "", + null + ); responseMap = parseResponseToMap(response); assertEquals("OpenAI Connector", (String) responseMap.get("name")); assertEquals("1", (String) responseMap.get("version")); @@ -102,7 +110,15 @@ public void testDeleteConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/connectors/" + connectorId, null, "", null); + response = TestHelper + .makeRequest( + client(), + "DELETE", + "/_plugins/_ml/connectors/" + connectorId, + null, + "", + null + ); responseMap = parseResponseToMap(response); assertEquals("deleted", (String) responseMap.get("result")); } @@ -111,12 +127,24 @@ public void testSearchConnectors() throws IOException { disableClusterConnectorAccessControl(); setEncryptionMasterKey(); createConnector(completionModelConnectorEntity); - String searchEntity = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " },\n" + " \"size\": 1000\n" + "}"; + String searchEntity = "{\n" + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " },\n" + + " \"size\": 1000\n" + + "}"; Response response = TestHelper - .makeRequest(client(), "GET", "/_plugins/_ml/connectors/_search", null, TestHelper.toHttpEntity(searchEntity), null); + .makeRequest( + client(), + "GET", + "/_plugins/_ml/connectors/_search", + null, + TestHelper.toHttpEntity(searchEntity), + null + ); Map responseMap = parseResponseToMap(response); assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); - + } public void testRegisterRemoteModel() throws IOException, InterruptedException { @@ -171,7 +199,11 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"prompt\": \"Say this is a test\"\n" + + " }\n" + + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -222,19 +254,17 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep + " \"model\": \"gpt-3.5-turbo\"\n" + " },\n" + " \"credential\": {\n" - + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") - + "\"\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + " },\n" + " \"actions\": [\n" + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"\n" + " }\n" + " ]\n" + "}"; @@ -263,33 +293,31 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep public void testOpenAIEditsModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI Edit model Connector\",\n" - + " \"description\": \"The connector to public OpenAI edit model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"text-davinci-edit-001\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/edits\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI Edit model Connector\",\n" + + " \"description\": \"The connector to public OpenAI edit model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"text-davinci-edit-001\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/edits\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -310,55 +338,45 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { + " \"instruction\": \"Fix the spelling mistakes\"\n" + " }\n" + "}"; - // Currently, the OpenAI Edits model does not work 100% of the time. This loop tries 3 times to see if the problem is with the model - // or with OpenSearch. - boolean editsModelSuccess = false; - for (int i = 0; i < 5; i++) { - response = predictRemoteModel(modelId, predictInput); - responseMap = parseResponseToMap(response); - List responseList = (List) responseMap.get("inference_results"); - responseMap = (Map) responseList.get(0); - responseList = (List) responseMap.get("output"); - responseMap = (Map) responseList.get(0); - responseMap = (Map) responseMap.get("dataAsMap"); - responseList = (List) responseMap.get("choices"); - responseMap = (Map) responseList.get(0); - if (((String) responseMap.get("text")).contains("What day of the week is it?")) { - editsModelSuccess = true; - } - } - assertTrue(editsModelSuccess); + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("choices"); + responseMap = (Map) responseList.get(0); + assertNotNull(((String) responseMap.get("text"))); } public void testOpenAIModerationsModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI moderations model Connector\",\n" - + " \"description\": \"The connector to public OpenAI moderations model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"moderations\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/moderations\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"input\\\": \\\"${parameters.input}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI moderations model Connector\",\n" + + " \"description\": \"The connector to public OpenAI moderations model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"moderations\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/moderations\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"input\\\": \\\"${parameters.input}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -373,7 +391,11 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"I want to kill them.\"\n" + " }\n" + "}"; + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": \"I want to kill them.\"\n" + + " }\n" + + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -391,35 +413,33 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio @Ignore public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI text embedding model Connector\",\n" - + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"text-embedding-ada-002\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n" - + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n" - + " \"post_process_function\": \"openai_embedding\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI text embedding model Connector\",\n" + + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"text-embedding-ada-002\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n" + + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n" + + " \"post_process_function\": \"openai_embedding\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -434,7 +454,11 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"The food was delicious\"\n" + " }\n" + "}"; + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": \"The food was delicious\"\n" + + " }\n" + + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -449,33 +473,31 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept public void testCohereGenerateTextModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"Cohere generate text model Connector\",\n" - + " \"description\": \"The connector to public Cohere generate text model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" - + System.getenv("COHERE_KEY") - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"Cohere generate text model Connector\",\n" + + " \"description\": \"The connector to public Cohere generate text model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.cohere.ai\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": \"20\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"cohere_key\": \"" + System.getenv("COHERE_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -511,33 +533,31 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti @Ignore public void testCohereClassifyModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"Cohere classify model Connector\",\n" - + " \"description\": \"The connector to public Cohere classify model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" - + System.getenv("COHERE_KEY") - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"Cohere classify model Connector\",\n" + + " \"description\": \"The connector to public Cohere classify model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.cohere.ai\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": \"20\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"cohere_key\": \"" + System.getenv("COHERE_KEY") + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -608,7 +628,15 @@ public void testCohereClassifyModel() throws IOException, InterruptedException { } private Response createConnector(String input) throws IOException { - return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null); + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/connectors/_create", + null, + TestHelper.toHttpEntity(input), + null + ); } private Response registerRemoteModel(String name, String connectorId) throws IOException { @@ -624,48 +652,71 @@ private Response registerRemoteModel(String name, String connectorId) throws IOE null, TestHelper.toHttpEntity(registerModelGroupEntity), null - ); + ); Map responseMap = parseResponseToMap(response); assertEquals((String) responseMap.get("status"), "CREATED"); String modelGroupId = (String) responseMap.get("model_group_id"); String registerModelEntity = "{\n" - + " \"name\": \"" - + name - + "\",\n" + + " \"name\": \"" + name + "\",\n" + " \"function_name\": \"remote\",\n" - + " \"model_group_id\": \"" - + modelGroupId - + "\",\n" + + " \"model_group_id\": \"" + modelGroupId + "\",\n" + " \"version\": \"1.0.0\",\n" + " \"description\": \"test model\",\n" - + " \"connector_id\": \"" - + connectorId - + "\"\n" + + " \"connector_id\": \"" + connectorId + "\"\n" + "}"; return TestHelper - .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/_register", + null, + TestHelper.toHttpEntity(registerModelEntity), + null + ); } private Response deployRemoteModel(String modelId) throws IOException { - return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null); + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/" + modelId + "/_deploy", + null, + "", + null + ); } private Response predictRemoteModel(String modelId, String input) throws IOException { - return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, input, null); + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/" + modelId + "/_predict", + null, + input, + null + ); } private Response undeployRemoteModel(String modelId) throws IOException { String undeployEntity = "{\n" + " \"SYqCMdsFTumUwoHZcsgiUg\": {\n" + " \"stats\": {\n" - + " \"" - + modelId - + "\": \"undeployed\"\n" + + " \"" + modelId + "\": \"undeployed\"\n" + " }\n" + " }\n" + "}"; - return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, undeployEntity, null); + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/" + modelId + "/_undeploy", + null, + undeployEntity, + null + ); } private Map parseResponseToMap(Response response) throws IOException { From 505087afea55df27dfa5b398aac911ac3077b681 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 11 Jul 2023 15:06:27 +0000 Subject: [PATCH 05/15] Fix spotless Signed-off-by: Ryan Bogan --- .../ml/rest/RestMLRemoteInferenceIT.java | 467 ++++++++---------- 1 file changed, 204 insertions(+), 263 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index a746c8324d..64aa903cf0 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -12,14 +12,12 @@ import org.apache.http.HttpEntity; import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; -import org.junit.After; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; import org.opensearch.ml.common.MLTaskState; -import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.utils.TestHelper; import com.google.common.collect.ImmutableList; @@ -27,34 +25,36 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { private final String completionModelConnectorEntity = "{\n" - + "\"name\": \"OpenAI Connector\",\n" - + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" - + "\"version\": 1,\n" - + "\"protocol\": \"http\",\n" - + "\"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": 7,\n" - + " \"temperature\": 0,\n" - + " \"model\": \"text-davinci-003\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" - + " \"headers\": {\n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" - + " }\n" - + " ]\n" - + "}"; - + + "\"name\": \"OpenAI Connector\",\n" + + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + "\"version\": 1,\n" + + "\"protocol\": \"http\",\n" + + "\"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"text-davinci-003\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ]\n" + + "}"; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -71,7 +71,7 @@ private void setEncryptionMasterKey() throws IOException { ); assertEquals(200, response.getStatusLine().getStatusCode()); } - + public void testCreateConnector() throws IOException { System.out.println(System.getenv()); disableClusterConnectorAccessControl(); @@ -87,15 +87,7 @@ public void testGetConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = TestHelper - .makeRequest( - client(), - "GET", - "/_plugins/_ml/connectors/" + connectorId, - null, - "", - null - ); + response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/connectors/" + connectorId, null, "", null); responseMap = parseResponseToMap(response); assertEquals("OpenAI Connector", (String) responseMap.get("name")); assertEquals("1", (String) responseMap.get("version")); @@ -110,15 +102,7 @@ public void testDeleteConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = TestHelper - .makeRequest( - client(), - "DELETE", - "/_plugins/_ml/connectors/" + connectorId, - null, - "", - null - ); + response = TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/connectors/" + connectorId, null, "", null); responseMap = parseResponseToMap(response); assertEquals("deleted", (String) responseMap.get("result")); } @@ -127,24 +111,12 @@ public void testSearchConnectors() throws IOException { disableClusterConnectorAccessControl(); setEncryptionMasterKey(); createConnector(completionModelConnectorEntity); - String searchEntity = "{\n" - + " \"query\": {\n" - + " \"match_all\": {}\n" - + " },\n" - + " \"size\": 1000\n" - + "}"; + String searchEntity = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " },\n" + " \"size\": 1000\n" + "}"; Response response = TestHelper - .makeRequest( - client(), - "GET", - "/_plugins/_ml/connectors/_search", - null, - TestHelper.toHttpEntity(searchEntity), - null - ); + .makeRequest(client(), "GET", "/_plugins/_ml/connectors/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); - + } public void testRegisterRemoteModel() throws IOException, InterruptedException { @@ -199,11 +171,7 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"prompt\": \"Say this is a test\"\n" - + " }\n" - + "}"; + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -254,17 +222,19 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep + " \"model\": \"gpt-3.5-turbo\"\n" + " },\n" + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + " },\n" + " \"actions\": [\n" + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"\n" + " }\n" + " ]\n" + "}"; @@ -293,31 +263,33 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep public void testOpenAIEditsModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI Edit model Connector\",\n" - + " \"description\": \"The connector to public OpenAI edit model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"text-davinci-edit-001\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/edits\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI Edit model Connector\",\n" + + " \"description\": \"The connector to public OpenAI edit model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"text-davinci-edit-001\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/edits\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -352,31 +324,33 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { public void testOpenAIModerationsModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI moderations model Connector\",\n" - + " \"description\": \"The connector to public OpenAI moderations model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"moderations\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/moderations\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"input\\\": \\\"${parameters.input}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI moderations model Connector\",\n" + + " \"description\": \"The connector to public OpenAI moderations model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"moderations\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/moderations\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"input\\\": \\\"${parameters.input}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -391,11 +365,7 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": \"I want to kill them.\"\n" - + " }\n" - + "}"; + String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"I want to kill them.\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -413,33 +383,35 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio @Ignore public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"OpenAI text embedding model Connector\",\n" - + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"model\": \"text-embedding-ada-002\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"openAI_key\": \"" + System.getenv("OPENAI_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n" - + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n" - + " \"post_process_function\": \"openai_embedding\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"OpenAI text embedding model Connector\",\n" + + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"model\": \"text-embedding-ada-002\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + System.getenv("OPENAI_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/embeddings\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n" + + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n" + + " \"post_process_function\": \"openai_embedding\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -454,11 +426,7 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": \"The food was delicious\"\n" - + " }\n" - + "}"; + String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"The food was delicious\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -473,31 +441,33 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept public void testCohereGenerateTextModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"Cohere generate text model Connector\",\n" - + " \"description\": \"The connector to public Cohere generate text model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" + System.getenv("COHERE_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"Cohere generate text model Connector\",\n" + + " \"description\": \"The connector to public Cohere generate text model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.cohere.ai\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": \"20\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"cohere_key\": \"" + + System.getenv("COHERE_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -533,31 +503,33 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti @Ignore public void testCohereClassifyModel() throws IOException, InterruptedException { String entity = "{\n" - + " \"name\": \"Cohere classify model Connector\",\n" - + " \"description\": \"The connector to public Cohere classify model service\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"http\",\n" - + " \"parameters\": {\n" - + " \"endpoint\": \"api.cohere.ai\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" - + " \"max_tokens\": \"20\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"cohere_key\": \"" + System.getenv("COHERE_KEY") + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" - + " \"headers\": { \n" - + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" - + " }\n" - + " ]\n" - + "}"; + + " \"name\": \"Cohere classify model Connector\",\n" + + " \"description\": \"The connector to public Cohere classify model service\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.cohere.ai\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": \"20\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"cohere_key\": \"" + + System.getenv("COHERE_KEY") + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n" + + " \"headers\": { \n" + + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -628,15 +600,7 @@ public void testCohereClassifyModel() throws IOException, InterruptedException { } private Response createConnector(String input) throws IOException { - return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/connectors/_create", - null, - TestHelper.toHttpEntity(input), - null - ); + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null); } private Response registerRemoteModel(String name, String connectorId) throws IOException { @@ -652,71 +616,48 @@ private Response registerRemoteModel(String name, String connectorId) throws IOE null, TestHelper.toHttpEntity(registerModelGroupEntity), null - ); + ); Map responseMap = parseResponseToMap(response); assertEquals((String) responseMap.get("status"), "CREATED"); String modelGroupId = (String) responseMap.get("model_group_id"); String registerModelEntity = "{\n" - + " \"name\": \"" + name + "\",\n" + + " \"name\": \"" + + name + + "\",\n" + " \"function_name\": \"remote\",\n" - + " \"model_group_id\": \"" + modelGroupId + "\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + " \"version\": \"1.0.0\",\n" + " \"description\": \"test model\",\n" - + " \"connector_id\": \"" + connectorId + "\"\n" + + " \"connector_id\": \"" + + connectorId + + "\"\n" + "}"; return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/models/_register", - null, - TestHelper.toHttpEntity(registerModelEntity), - null - ); + .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } private Response deployRemoteModel(String modelId) throws IOException { - return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/models/" + modelId + "/_deploy", - null, - "", - null - ); + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null); } private Response predictRemoteModel(String modelId, String input) throws IOException { - return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/models/" + modelId + "/_predict", - null, - input, - null - ); + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, input, null); } private Response undeployRemoteModel(String modelId) throws IOException { String undeployEntity = "{\n" + " \"SYqCMdsFTumUwoHZcsgiUg\": {\n" + " \"stats\": {\n" - + " \"" + modelId + "\": \"undeployed\"\n" + + " \"" + + modelId + + "\": \"undeployed\"\n" + " }\n" + " }\n" + "}"; - return TestHelper - .makeRequest( - client(), - "POST", - "/_plugins/_ml/models/" + modelId + "/_undeploy", - null, - undeployEntity, - null - ); + return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, undeployEntity, null); } private Map parseResponseToMap(Response response) throws IOException { From 6ce0f7c4e84e829c2974f32985cf6af5af947732 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 11 Jul 2023 17:12:24 +0000 Subject: [PATCH 06/15] Refactoring Signed-off-by: Ryan Bogan --- .../ml/rest/RestMLRemoteInferenceIT.java | 45 +++++++------------ 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 64aa903cf0..2639319e0d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -59,31 +59,19 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { public ExpectedException exceptionRule = ExpectedException.none(); @Before - private void setEncryptionMasterKey() throws IOException { - Response response = TestHelper - .makeRequest( - client(), - "PUT", - "_cluster/settings", - null, - "{\"persistent\":{\"plugins.ml_commons.encryption.master_key\":\"0000000000000011\"}}", - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) - ); - assertEquals(200, response.getStatusLine().getStatusCode()); + public void setup() throws IOException { + disableClusterConnectorAccessControl(); + setEncryptionMasterKey();; } + public void testCreateConnector() throws IOException { - System.out.println(System.getenv()); - disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); assertNotNull((String) responseMap.get("connector_id")); } public void testGetConnector() throws IOException { - disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -97,8 +85,6 @@ public void testGetConnector() throws IOException { @Ignore public void testDeleteConnector() throws IOException { - disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -108,8 +94,6 @@ public void testDeleteConnector() throws IOException { } public void testSearchConnectors() throws IOException { - disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); createConnector(completionModelConnectorEntity); String searchEntity = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " },\n" + " \"size\": 1000\n" + "}"; Response response = TestHelper @@ -120,8 +104,6 @@ public void testSearchConnectors() throws IOException { } public void testRegisterRemoteModel() throws IOException, InterruptedException { - disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -135,8 +117,6 @@ public void testRegisterRemoteModel() throws IOException, InterruptedException { } public void testDeployRemoteModel() throws IOException, InterruptedException { - disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -155,8 +135,6 @@ public void testDeployRemoteModel() throws IOException, InterruptedException { } public void testPredictRemoteModel() throws IOException, InterruptedException { - disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -185,8 +163,6 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { } public void testUndeployRemoteModel() throws IOException, InterruptedException { - disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -680,6 +656,19 @@ private void disableClusterConnectorAccessControl() throws IOException { assertEquals(200, response.getStatusLine().getStatusCode()); } + private void setEncryptionMasterKey() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.encryption.master_key\":\"0000000000000011\"}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + private Response getTask(String taskId) throws IOException { return TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null); } From e2d28f56f85f6fd4c4976d4bf479371dc2cc45c0 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 11 Jul 2023 17:29:01 +0000 Subject: [PATCH 07/15] Fix typo Signed-off-by: Ryan Bogan --- .../java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 2639319e0d..7b4ac057e3 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -61,9 +61,8 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { @Before public void setup() throws IOException { disableClusterConnectorAccessControl(); - setEncryptionMasterKey();; + setEncryptionMasterKey(); } - public void testCreateConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); From 88803d106e88b5251c85f1fd5cb19e839544e5e5 Mon Sep 17 00:00:00 2001 From: Peter Zhu Date: Wed, 12 Jul 2023 15:44:13 -0400 Subject: [PATCH 08/15] Remove env vars Signed-off-by: Peter Zhu --- .github/workflows/CI-workflow.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index fd71212c89..d9fc74c080 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -1,7 +1,4 @@ name: Build and Test ml-commons -env: - OPENAI_KEY: ${{ secrets.OPENAI_KEY }} - COHERE_KEY: ${{ secrets.COHERE_KEY }} # This workflow is triggered on pull requests and push to any branches on: pull_request: From 494821f74189ad1a0e6714ec5b93694519d22307 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 12 Jul 2023 20:15:33 +0000 Subject: [PATCH 09/15] Removing unnecessary comments Signed-off-by: Ryan Bogan --- .github/workflows/CI-workflow.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index fdad02e4e0..212b411197 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -35,7 +35,6 @@ jobs: role-to-assume: ${{ secrets.ML_ROLE }} aws-region: us-west-2 - # ml-commons - name: Checkout MLCommons uses: actions/checkout@v2 From ca4c32735bf87556a4584e65e573ffcb0e9d8e74 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 12 Jul 2023 21:22:21 +0000 Subject: [PATCH 10/15] Rebase with 2.x and remove set up for master key Signed-off-by: Ryan Bogan --- .../ml/rest/RestMLRemoteInferenceIT.java | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 7b4ac057e3..528e25576f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -59,9 +59,9 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { public ExpectedException exceptionRule = ExpectedException.none(); @Before - public void setup() throws IOException { + public void setup() throws IOException, InterruptedException { disableClusterConnectorAccessControl(); - setEncryptionMasterKey(); + Thread.sleep(10000); } public void testCreateConnector() throws IOException { @@ -343,6 +343,7 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"I want to kill them.\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); + System.out.println(responseMap); List responseList = (List) responseMap.get("inference_results"); responseMap = (Map) responseList.get(0); responseList = (List) responseMap.get("output"); @@ -655,19 +656,6 @@ private void disableClusterConnectorAccessControl() throws IOException { assertEquals(200, response.getStatusLine().getStatusCode()); } - private void setEncryptionMasterKey() throws IOException { - Response response = TestHelper - .makeRequest( - client(), - "PUT", - "_cluster/settings", - null, - "{\"persistent\":{\"plugins.ml_commons.encryption.master_key\":\"0000000000000011\"}}", - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) - ); - assertEquals(200, response.getStatusLine().getStatusCode()); - } - private Response getTask(String taskId) throws IOException { return TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null); } From be539eafbf3208361d3f56b49690f1e0f79944bd Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 12 Jul 2023 21:25:21 +0000 Subject: [PATCH 11/15] Remove print line Signed-off-by: Ryan Bogan --- .../java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java | 1 - 1 file changed, 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 528e25576f..0f897d556e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -343,7 +343,6 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"I want to kill them.\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); - System.out.println(responseMap); List responseList = (List) responseMap.get("inference_results"); responseMap = (Map) responseList.get(0); responseList = (List) responseMap.get("output"); From 42aac31d1e8a9f7c5e0d63b5888f3017f1e34092 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 12 Jul 2023 23:15:36 +0000 Subject: [PATCH 12/15] Add throttling check Signed-off-by: Ryan Bogan --- .../ml/rest/RestMLRemoteInferenceIT.java | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 0f897d556e..f1d24c7f38 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -99,7 +99,6 @@ public void testSearchConnectors() throws IOException { .makeRequest(client(), "GET", "/_plugins/_ml/connectors/_search", null, TestHelper.toHttpEntity(searchEntity), null); Map responseMap = parseResponseToMap(response); assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); - } public void testRegisterRemoteModel() throws IOException, InterruptedException { @@ -157,6 +156,10 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { responseMap = (Map) responseList.get(0); responseMap = (Map) responseMap.get("dataAsMap"); responseList = (List) responseMap.get("choices"); + if (responseList == null) { + assertTrue(checkThrottlingOpenAI(responseMap)); + return; + } responseMap = (Map) responseList.get(0); assertEquals("\n\nThis is indeed a test", (String) responseMap.get("text")); } @@ -293,6 +296,10 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { responseMap = (Map) responseList.get(0); responseMap = (Map) responseMap.get("dataAsMap"); responseList = (List) responseMap.get("choices"); + if (responseList == null) { + assertTrue(checkThrottlingOpenAI(responseMap)); + return; + } responseMap = (Map) responseList.get(0); assertNotNull(((String) responseMap.get("text"))); } @@ -349,6 +356,10 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio responseMap = (Map) responseList.get(0); responseMap = (Map) responseMap.get("dataAsMap"); responseList = (List) responseMap.get("results"); + if (responseList == null) { + assertTrue(checkThrottlingOpenAI(responseMap)); + return; + } responseMap = (Map) responseList.get(0); assertTrue((Boolean) responseMap.get("flagged")); responseMap = (Map) responseMap.get("categories"); @@ -635,6 +646,12 @@ private Response undeployRemoteModel(String modelId) throws IOException { return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, undeployEntity, null); } + private boolean checkThrottlingOpenAI(Map responseMap) { + Map map = (Map) responseMap.get("error"); + String message = (String) map.get("message"); + return message.equals("You exceeded your current quota, please check your plan and billing details."); + } + private Map parseResponseToMap(Response response) throws IOException { HttpEntity entity = response.getEntity(); assertNotNull(response); From dc027bbe28ea9686267ae5092b5b7733926a32ab Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 12 Jul 2023 23:34:35 +0000 Subject: [PATCH 13/15] Minor change for assert statement Signed-off-by: Ryan Bogan --- .../java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index f1d24c7f38..9484370825 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -161,7 +161,7 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { return; } responseMap = (Map) responseList.get(0); - assertEquals("\n\nThis is indeed a test", (String) responseMap.get("text")); + assertNotNull(responseMap.get("text")); } public void testUndeployRemoteModel() throws IOException, InterruptedException { From 0ca24cfe7b34d9550d37c7039f474b2d0c36d777 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 12 Jul 2023 23:48:36 +0000 Subject: [PATCH 14/15] Test changes Signed-off-by: Ryan Bogan --- .../ml/rest/RestMLRemoteInferenceIT.java | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 9484370825..b40dc25836 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -161,7 +161,7 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { return; } responseMap = (Map) responseList.get(0); - assertNotNull(responseMap.get("text")); + assertFalse(((String) responseMap.get("text")).isEmpty()); } public void testUndeployRemoteModel() throws IOException, InterruptedException { @@ -184,7 +184,6 @@ public void testUndeployRemoteModel() throws IOException, InterruptedException { assertTrue(responseMap.toString().contains("undeployed")); } - @Ignore public void testOpenAIChatCompletionModel() throws IOException, InterruptedException { String entity = "{\n" + " \"name\": \"OpenAI chat model Connector\",\n" @@ -237,6 +236,8 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); + // TODO handle throttling error + assertNotNull(responseMap); } public void testOpenAIEditsModel() throws IOException, InterruptedException { @@ -301,7 +302,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { return; } responseMap = (Map) responseList.get(0); - assertNotNull(((String) responseMap.get("text"))); + assertFalse(((String) responseMap.get("text")).isEmpty()); } public void testOpenAIModerationsModel() throws IOException, InterruptedException { @@ -486,7 +487,6 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti assertFalse(((String) responseMap.get("text")).isEmpty()); } - @Ignore public void testCohereClassifyModel() throws IOException, InterruptedException { String entity = "{\n" + " \"name\": \"Cohere classify model Connector\",\n" @@ -583,6 +583,13 @@ public void testCohereClassifyModel() throws IOException, InterruptedException { response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("classifications"); + assertFalse(responseList.isEmpty()); } private Response createConnector(String input) throws IOException { From d43f3a47cd0e3010707a818886c43e4561a66fb3 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 13 Jul 2023 00:06:03 +0000 Subject: [PATCH 15/15] Uncomment delete connector test Signed-off-by: Ryan Bogan --- .../java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java | 1 - 1 file changed, 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index b40dc25836..707cf88baa 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -82,7 +82,6 @@ public void testGetConnector() throws IOException { assertEquals("http", (String) responseMap.get("protocol")); } - @Ignore public void testDeleteConnector() throws IOException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response);