diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
index 4b9c6dfdce..79a4adeab5 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
@@ -67,6 +67,657 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {
     @Rule
     public ExpectedException exceptionRule = ExpectedException.none();
 
+    /** Creates a connector from the shared completion-model entity and checks the reported status. */
+    public void testCreateConnector() throws IOException {
+        disableClusterConnectorAccessControl();
+        Response response = createConnector(completionModelConnectorEntity);
+        Map responseMap = parseResponseToMap(response);
+        assertEquals("CREATED", (String) responseMap.get("status"));
+    }
+
+    /** Creates a connector, fetches it by id and checks the returned fields. */
+    public void testGetConnector() throws IOException {
+        disableClusterConnectorAccessControl();
+        Response response = createConnector(completionModelConnectorEntity);
+        Map responseMap = parseResponseToMap(response);
+        String connectorId = (String) responseMap.get("connector_id");
+        response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/connectors/" + connectorId, null, "", null);
+        responseMap = parseResponseToMap(response);
+        assertEquals("OpenAI Connector", (String) responseMap.get("name"));
+        assertEquals("1", (String) responseMap.get("version"));
+        assertEquals("The connector to public OpenAI model service for GPT 3.5", (String) responseMap.get("description"));
+        assertEquals("http/v1", (String) responseMap.get("protocol"));
+        assertEquals("CREATED", (String) responseMap.get("connector_state"));
+    }
+
+    /** Creates a connector, deletes it by id and checks the delete result. */
+    public void testDeleteConnector() throws IOException {
+        disableClusterConnectorAccessControl();
+        Response response = createConnector(completionModelConnectorEntity);
+        Map responseMap = parseResponseToMap(response);
+        String connectorId = (String) responseMap.get("connector_id");
+        response = TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/connectors/" + connectorId, null, "", null);
+        responseMap = parseResponseToMap(response);
+        assertEquals("deleted", (String) responseMap.get("result"));
+    }
+
+    /** Creates one connector and checks that a match_all search reports exactly one hit. */
+    public void testSearchConnectors() throws IOException {
+        disableClusterConnectorAccessControl();
+        createConnector(completionModelConnectorEntity);
+        String searchEntity = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " },\n" + " \"size\": 1000\n" + "}";
+        Response response = TestHelper
+            .makeRequest(client(), "GET", "/_plugins/_ml/connectors/_search", null, TestHelper.toHttpEntity(searchEntity), null);
+        Map responseMap = parseResponseToMap(response);
+        Map hits = (Map) responseMap.get("hits");
+        Map total = (Map) hits.get("total");
+        assertEquals((Double) 1.0, (Double) total.get("value"));
+    }
+
+    /** Registers a remote model for a new connector and checks the finished task carries a model id. */
+    public void testRegisterRemoteModel() throws IOException, InterruptedException {
+        disableClusterConnectorAccessControl();
+        Response response = createConnector(completionModelConnectorEntity);
+        Map responseMap = parseResponseToMap(response);
+        String connectorId = (String) responseMap.get("connector_id");
+        response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
+        responseMap = parseResponseToMap(response);
+        String taskId = (String) responseMap.get("task_id");
+        waitForTask(taskId, MLTaskState.COMPLETED);
+        response = getTask(taskId);
+        responseMap = parseResponseToMap(response);
+        assertNotNull(responseMap.get("model_id"));
+    }
+
+    /** Registers and deploys a remote model, checking the deploy response status and task completion. */
+    public void testDeployRemoteModel() throws IOException, InterruptedException {
+        disableClusterConnectorAccessControl();
+        Response response = createConnector(completionModelConnectorEntity);
+        Map responseMap = parseResponseToMap(response);
+        String connectorId = (String) responseMap.get("connector_id");
+        response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
+        responseMap = parseResponseToMap(response);
+        String taskId = (String) responseMap.get("task_id");
+        waitForTask(taskId, MLTaskState.COMPLETED);
+        response = getTask(taskId);
+        responseMap = parseResponseToMap(response);
+        String modelId = (String) responseMap.get("model_id");
+        response = deployRemoteModel(modelId);
+        responseMap = parseResponseToMap(response);
+        assertEquals("CREATED", (String) responseMap.get("status"));
+        taskId = (String) responseMap.get("task_id");
+        waitForTask(taskId, MLTaskState.COMPLETED);
+    }
+
+    /** Deploys the shared completion model and checks the text of the first returned choice. */
+    public void testPredictRemoteModel() throws IOException, InterruptedException {
+        disableClusterConnectorAccessControl();
+        String modelId = createAndDeployRemoteModel(completionModelConnectorEntity, "openAI-GPT-3.5 completions");
+        String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
+        Map responseMap = (Map) firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap");
+        List responseList = (List) responseMap.get("choices");
+        responseMap = (Map) responseList.get(0);
+        assertEquals("\n\nThis is indeed a test", (String) responseMap.get("text"));
+    }
+
+    /** Deploys the shared completion model, undeploys it and checks the response mentions "undeployed". */
+    public void testUndeployRemoteModel() throws IOException, InterruptedException {
+        disableClusterConnectorAccessControl();
+        String modelId = createAndDeployRemoteModel(completionModelConnectorEntity, "openAI-GPT-3.5 completions");
+        Response response = undeployRemoteModel(modelId);
+        Map responseMap = parseResponseToMap(response);
+        assertTrue(responseMap.toString().contains("undeployed"));
+    }
+
+    /** Runs a completion against the public OpenAI service; needs OPENAI_KEY in the environment. */
+    @Ignore
+    public void testOpenAICompletionModel() throws IOException, InterruptedException {
+        String openAiKey = credentialFromEnvOrSkip("OPENAI_KEY");
+        String entity = "{\n"
+            + " \"name\": \"OpenAI Connector\",\n"
+            + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"http/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"api.openai.com\",\n"
+            + " \"auth\": \"API_Key\",\n"
+            + " \"content_type\": \"application/json\",\n"
+            + " \"max_tokens\": 7,\n"
+            + " \"temperature\": 0,\n"
+            + " \"model\": \"text-davinci-003\"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"openAI_key\": \"" + openAiKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://api.openai.com/v1/completions\",\n"
+            + " \"headers\": { \n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n"
+            + " }\n"
+            + " },\n"
+            + " {\n"
+            + " \"metadata\": {\n"
+            + " \"method\": \"GET\",\n"
+            + " \"url\": \"https://${parameters.endpoint}/v1/models/{model}\",\n"
+            + " \"headers\": {\n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " }\n"
+            + " }\n"
+            + " }\n"
+            + " ],\n"
+            + " \"add_all_backend_roles\": true\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "openAI-GPT-3.5 completions");
+        String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
+        Map responseMap = (Map) firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap");
+        List responseList = (List) responseMap.get("choices");
+        responseMap = (Map) responseList.get(0);
+        assertEquals("\n\nThis is indeed a test", (String) responseMap.get("text"));
+    }
+
+    /** Sends a chat message to the public OpenAI chat service; needs OPENAI_KEY in the environment. */
+    @Ignore
+    public void testOpenAIChatCompletionModel() throws IOException, InterruptedException {
+        String openAiKey = credentialFromEnvOrSkip("OPENAI_KEY");
+        String entity = "{\n"
+            + " \"name\": \"OpenAI chat model Connector\",\n"
+            + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"http/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"api.openai.com\",\n"
+            + " \"auth\": \"API_Key\",\n"
+            + " \"content_type\": \"application/json\",\n"
+            + " \"max_tokens\": 7,\n"
+            + " \"temperature\": 0,\n"
+            + " \"model\": \"gpt-3.5-turbo\"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"openAI_key\": \"" + openAiKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n"
+            + " \"headers\": { \n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"\n"
+            + " }\n"
+            + " },\n"
+            + " {\n"
+            + " \"metadata\": {\n"
+            + " \"method\": \"GET\",\n"
+            + " \"url\": \"https://${parameters.endpoint}/v1/models/{model}\",\n"
+            + " \"headers\": {\n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " }\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "openAI-GPT-3.5 chat model");
+        String predictInput = "{\n"
+            + " \"parameters\": {\n"
+            + " \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}]\n"
+            + " }\n"
+            + "}";
+        // Assert on the response instead of printing it, so the test can actually fail.
+        assertNotNull(firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap"));
+    }
+
+    /** Asks the public OpenAI edits service to fix a misspelling; needs OPENAI_KEY in the environment. */
+    public void testOpenAIEditsModel() throws IOException, InterruptedException {
+        String openAiKey = credentialFromEnvOrSkip("OPENAI_KEY");
+        String entity = "{\n"
+            + " \"name\": \"OpenAI Edit model Connector\",\n"
+            + " \"description\": \"The connector to public OpenAI edit model service\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"http/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"api.openai.com\",\n"
+            + " \"auth\": \"API_Key\",\n"
+            + " \"content_type\": \"application/json\",\n"
+            + " \"model\": \"text-davinci-edit-001\"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"openAI_key\": \"" + openAiKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://api.openai.com/v1/edits\",\n"
+            + " \"headers\": { \n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n"
+            + " }\n"
+            + " },\n"
+            + " {\n"
+            + " \"metadata\": {\n"
+            + " \"method\": \"GET\",\n"
+            + " \"url\": \"https://${parameters.endpoint}/v1/models/{model}\",\n"
+            + " \"headers\": {\n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " }\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "openAI-GPT-3.5 edit model");
+        String predictInput = "{\n"
+            + " \"parameters\": {\n"
+            + " \"input\": \"What day of the wek is it?\",\n"
+            + " \"instruction\": \"Fix the spelling mistakes\"\n"
+            + " }\n"
+            + "}";
+        Map responseMap = (Map) firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap");
+        List responseList = (List) responseMap.get("choices");
+        responseMap = (Map) responseList.get(0);
+        assertTrue(((String) responseMap.get("text")).contains("What day of the week is it?"));
+    }
+
+    /** Checks the public OpenAI moderations service flags a violent input; needs OPENAI_KEY in the environment. */
+    public void testOpenAIModerationsModel() throws IOException, InterruptedException {
+        String openAiKey = credentialFromEnvOrSkip("OPENAI_KEY");
+        String entity = "{\n"
+            + " \"name\": \"OpenAI moderations model Connector\",\n"
+            + " \"description\": \"The connector to public OpenAI moderations model service\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"http/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"api.openai.com\",\n"
+            + " \"auth\": \"API_Key\",\n"
+            + " \"content_type\": \"application/json\",\n"
+            + " \"model\": \"moderations\"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"openAI_key\": \"" + openAiKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://api.openai.com/v1/moderations\",\n"
+            + " \"headers\": { \n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"input\\\": \\\"${parameters.input}\\\" }\"\n"
+            + " }\n"
+            + " },\n"
+            + " {\n"
+            + " \"metadata\": {\n"
+            + " \"method\": \"GET\",\n"
+            + " \"url\": \"https://${parameters.endpoint}/v1/models/{model}\",\n"
+            + " \"headers\": {\n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " }\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "openAI-GPT-3.5 moderations model");
+        String predictInput = "{\n"
+            + " \"parameters\": {\n"
+            + " \"input\": \"I want to kill them.\"\n"
+            + " }\n"
+            + "}";
+        Map responseMap = (Map) firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap");
+        List responseList = (List) responseMap.get("results");
+        responseMap = (Map) responseList.get(0);
+        assertTrue((Boolean) responseMap.get("flagged"));
+        responseMap = (Map) responseMap.get("categories");
+        assertTrue((Boolean) responseMap.get("violence"));
+    }
+
+    /** Requests an embedding from the public OpenAI service; needs OPENAI_KEY in the environment. */
+    public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException {
+        String openAiKey = credentialFromEnvOrSkip("OPENAI_KEY");
+        String entity = "{\n"
+            + " \"name\": \"OpenAI text embedding model Connector\",\n"
+            + " \"description\": \"The connector to public OpenAI text embedding model service\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"http/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"api.openai.com\",\n"
+            + " \"auth\": \"API_Key\",\n"
+            + " \"content_type\": \"application/json\",\n"
+            + " \"model\": \"text-embedding-ada-002\"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"openAI_key\": \"" + openAiKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://api.openai.com/v1/embeddings\",\n"
+            + " \"headers\": { \n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n"
+            + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n"
+            + " \"post_process_function\": \"openai_embedding\"\n"
+            + " }\n"
+            + " },\n"
+            + " {\n"
+            + " \"metadata\": {\n"
+            + " \"method\": \"GET\",\n"
+            + " \"url\": \"https://${parameters.endpoint}/v1/models/{model}\",\n"
+            + " \"headers\": {\n"
+            + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+            + " }\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "openAI text embedding model");
+        String predictInput = "{\n"
+            + " \"parameters\": {\n"
+            + " \"input\": \"The food was delicious\"\n"
+            + " }\n"
+            + "}";
+        Map responseMap = (Map) firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap");
+        List responseList = (List) responseMap.get("data");
+        responseMap = (Map) responseList.get(0);
+        assertFalse(((List) responseMap.get("embedding")).isEmpty());
+    }
+
+    /** Generates text with the public Cohere service; needs COHERE_KEY in the environment. */
+    public void testCohereGenerateTextModel() throws IOException, InterruptedException {
+        String cohereKey = credentialFromEnvOrSkip("COHERE_KEY");
+        String entity = "{\n"
+            + " \"name\": \"Cohere generate text model Connector\",\n"
+            + " \"description\": \"The connector to public Cohere generate text model service\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"http/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"api.cohere.ai\",\n"
+            + " \"auth\": \"API_Key\",\n"
+            + " \"content_type\": \"application/json\",\n"
+            + " \"max_tokens\": \"20\"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"cohere_key\": \"" + cohereKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://${parameters.endpoint}/v1/generate\",\n"
+            + " \"headers\": { \n"
+            + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"return_likelihoods\\\": \\\"NONE\\\", \\\"truncate\\\": \\\"END\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\" }\"\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "cohere generate text model");
+        String predictInput = "{\n"
+            + " \"parameters\": {\n"
+            + " \"prompt\": \"Once upon a time in a magical land called\",\n"
+            + " \"max_tokens\": 40\n"
+            + " }\n"
+            + "}";
+        Map responseMap = (Map) firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap");
+        List responseList = (List) responseMap.get("generations");
+        responseMap = (Map) responseList.get(0);
+        assertFalse(((String) responseMap.get("text")).isEmpty());
+    }
+
+    /** Classifies two inputs with the public Cohere service; needs COHERE_KEY in the environment. */
+    @Ignore
+    public void testCohereClassifyModel() throws IOException, InterruptedException {
+        String cohereKey = credentialFromEnvOrSkip("COHERE_KEY");
+        String entity = "{\n"
+            + " \"name\": \"Cohere classify model Connector\",\n"
+            + " \"description\": \"The connector to public Cohere classify model service\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"http/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"api.cohere.ai\",\n"
+            + " \"auth\": \"API_Key\",\n"
+            + " \"content_type\": \"application/json\",\n"
+            + " \"max_tokens\": \"20\"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"cohere_key\": \"" + cohereKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://${parameters.endpoint}/v1/classify\",\n"
+            + " \"headers\": { \n"
+            + " \"Authorization\": \"Bearer ${credential.cohere_key}\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"inputs\\\": ${parameters.inputs}, \\\"examples\\\": ${parameters.examples}, \\\"truncate\\\": \\\"END\\\" }\"\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "cohere classify model");
+        String predictInput = "{\n"
+            + " \"parameters\": {\n"
+            + " \"inputs\": [\n"
+            + " \"Confirm your email address\",\n"
+            + " \"hey i need u to send some $\"\n"
+            + " ],\n"
+            + " \"examples\": [\n"
+            + " {\n"
+            + " \"text\": \"Dermatologists don't like her!\",\n"
+            + " \"label\": \"Spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"Hello, open to this?\",\n"
+            + " \"label\": \"Spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"I need help please wire me $1000 right now\",\n"
+            + " \"label\": \"Spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"Nice to know you ;)\",\n"
+            + " \"label\": \"Spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"Please help me?\",\n"
+            + " \"label\": \"Spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"Your parcel will be delivered today\",\n"
+            + " \"label\": \"Not spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"Review changes to our Terms and Conditions\",\n"
+            + " \"label\": \"Not spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"Weekly sync notes\",\n"
+            + " \"label\": \"Not spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"Re: Follow up from todays meeting\",\n"
+            + " \"label\": \"Not spam\"\n"
+            + " },\n"
+            + " {\n"
+            + " \"text\": \"Pre-read for tomorrow\",\n"
+            + " \"label\": \"Not spam\"\n"
+            + " }\n"
+            + " ]\n"
+            + " }\n"
+            + "}";
+        // Assert on the response instead of printing it, so the test can actually fail.
+        assertNotNull(firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap"));
+    }
+
+    /** Requests an embedding from a SageMaker endpoint; needs AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. */
+    public void testSageMakerModel() throws IOException, InterruptedException {
+        String awsAccessKey = credentialFromEnvOrSkip("AWS_ACCESS_KEY_ID");
+        String awsSecretKey = credentialFromEnvOrSkip("AWS_SECRET_ACCESS_KEY");
+        String entity = "{\n"
+            + " \"name\": \"SageMaker huggingface test embedding Connector\",\n"
+            + " \"description\": \"The connector to sageMaker service for a text embedding model with huggingface\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"aws/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"runtime.sagemaker.us-west-2.amazonaws.com\",\n"
+            + " \"auth\": \"Sig_V4\",\n"
+            + " \"region\": \"us-west-2\",\n"
+            + " \"service_name\": \"sagemaker\",\n"
+            + " \"content_type\": \"application/json\",\n"
+            + " \"connector.pre_process_function\": \"\\n StringBuilder builder = new StringBuilder();\\n builder.append(\\\"\\\\\\\"\\\");\\n builder.append(params.text_docs[0]);\\n builder.append(\\\"\\\\\\\"\\\");\\n def parameters = \\\"{\\\" +\\\"\\\\\\\"inputs\\\\\\\":\\\" + builder + \\\"}\\\";\\n return \\\"{\\\" +\\\"\\\\\\\"parameters\\\\\\\":\\\" + parameters + \\\"}\\\";\",\n"
+            + " \"connector.post_process_function\": \"\\n def name = \\\"sentence_embedding\\\";\\n def dataType = \\\"FLOAT32\\\";\\n if (params.vectors == null || params.vectors.length == 0) {\\n return null;\\n }\\n def shape = [params.vectors.length];\\n def json = \\\"{\\\" +\\n \\\"\\\\\\\"name\\\\\\\":\\\\\\\"\\\" + name + \\\"\\\\\\\",\\\" +\\n \\\"\\\\\\\"data_type\\\\\\\":\\\\\\\"\\\" + dataType + \\\"\\\\\\\",\\\" +\\n \\\"\\\\\\\"shape\\\\\\\":\\\" + shape + \\\",\\\" +\\n \\\"\\\\\\\"data\\\\\\\":\\\" + params.vectors +\\n \\\"}\\\";\\n return json;\\n \"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"access_key\": \"" + awsAccessKey + "\",\n"
+            + " \"secret_key\": \"" + awsSecretKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/huggingface-pytorch-inference-2023-06-17-06-40-44-801/invocations\",\n"
+            + " \"headers\": { \n"
+            + " \"content-type\": \"application/json\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"inputs\\\": \\\"${parameters.inputs}\\\" }\"\n"
+            + " }\n"
+            + " },\n"
+            + " {\n"
+            + " \"metadata\": {\n"
+            + " \"method\": \"GET\",\n"
+            + " \"url\": \"https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/huggingface-pytorch-inference-2023-06-17-06-40-44-801/invocations\",\n"
+            + " \"headers\": {\n"
+            + " \"content-type\": \"application/json\"\n"
+            + " }\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "sagemaker hugging face model");
+        // The key must be "inputs": that is the placeholder the connector's request_body substitutes.
+        String predictInput = "{\n"
+            + " \"parameters\": {\n"
+            + " \"inputs\": \"The food was delicious\"\n"
+            + " }\n"
+            + "}";
+        Map responseMap = firstPredictOutput(predictRemoteModel(modelId, predictInput));
+        assertFalse(((List) responseMap.get("data")).isEmpty());
+    }
+
+    /** Requests an embedding from Bedrock; needs AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. */
+    public void testBedRockTextEmbeddingModel() throws IOException, InterruptedException {
+        String awsAccessKey = credentialFromEnvOrSkip("AWS_ACCESS_KEY_ID");
+        String awsSecretKey = credentialFromEnvOrSkip("AWS_SECRET_ACCESS_KEY");
+        String entity = "{\n"
+            + " \"name\": \"BedRock test embedding Connector\",\n"
+            + " \"description\": \"The connector to BedRock service for a text embedding model\",\n"
+            + " \"version\": 1,\n"
+            + " \"protocol\": \"aws/v1\",\n"
+            + " \"parameters\": {\n"
+            + " \"endpoint\": \"bedrock.us-east-1.amazonaws.com\",\n"
+            + " \"auth\": \"Sig_V4\",\n"
+            + " \"region\": \"us-east-1\",\n"
+            + " \"service_name\": \"bedrock\",\n"
+            + " \"content_type\": \"application/json\"\n"
+            + " },\n"
+            + " \"credential\": {\n"
+            + " \"access_key\": \"" + awsAccessKey + "\",\n"
+            + " \"secret_key\": \"" + awsSecretKey + "\"\n"
+            + " },\n"
+            + " \"actions\": [\n"
+            + " {\n"
+            + " \"predict\": {\n"
+            + " \"method\": \"POST\",\n"
+            + " \"url\": \"https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-e1t-medium/invoke\",\n"
+            + " \"headers\": { \n"
+            + " \"content-type\": \"application/json\",\n"
+            + " \"x-amz-content-sha256\": \"required\"\n"
+            + " },\n"
+            + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.inputText}\\\" }\"\n"
+            + " }\n"
+            + " },\n"
+            + " {\n"
+            + " \"metadata\": {\n"
+            + " \"method\": \"GET\",\n"
+            + " \"url\": \"https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-e1t-medium/ping\",\n"
+            + " \"headers\": {\n"
+            + " \"content-type\": \"application/json\"\n"
+            + " }\n"
+            + " }\n"
+            + " }\n"
+            + " ]\n"
+            + "}";
+        String modelId = createAndDeployRemoteModel(entity, "bedrock embedding model");
+        // The key must be "inputText": that is the placeholder the connector's request_body substitutes.
+        String predictInput = "{\n"
+            + " \"parameters\": {\n"
+            + " \"inputText\": \"The food was delicious\"\n"
+            + " }\n"
+            + "}";
+        Map responseMap = (Map) firstPredictOutput(predictRemoteModel(modelId, predictInput)).get("dataAsMap");
+        assertFalse(((List) responseMap.get("embedding")).isEmpty());
+    }
+
+    /**
+     * Reads a credential from an environment variable and skips the calling test when it is
+     * missing, so secrets never have to be committed with the test source.
+     *
+     * @param envVarName name of the environment variable holding the credential
+     * @return the non-empty value of the variable
+     */
+    private static String credentialFromEnvOrSkip(String envVarName) {
+        String value = System.getenv(envVarName);
+        org.junit.Assume.assumeTrue(envVarName + " is not set; skipping live remote-service test", value != null && !value.isEmpty());
+        return value;
+    }
+
+    /**
+     * Creates a connector, registers a remote model on it, deploys the model and waits for both
+     * tasks to reach COMPLETED.
+     *
+     * @param connectorEntity JSON body of the create-connector request
+     * @param modelName name to register the remote model under
+     * @return id of the deployed model
+     */
+    private String createAndDeployRemoteModel(String connectorEntity, String modelName) throws IOException, InterruptedException {
+        Map responseMap = parseResponseToMap(createConnector(connectorEntity));
+        String connectorId = (String) responseMap.get("connector_id");
+        responseMap = parseResponseToMap(registerRemoteModel(modelName, connectorId));
+        String taskId = (String) responseMap.get("task_id");
+        waitForTask(taskId, MLTaskState.COMPLETED);
+        responseMap = parseResponseToMap(getTask(taskId));
+        String modelId = (String) responseMap.get("model_id");
+        responseMap = parseResponseToMap(deployRemoteModel(modelId));
+        taskId = (String) responseMap.get("task_id");
+        waitForTask(taskId, MLTaskState.COMPLETED);
+        return modelId;
+    }
+
+    /**
+     * Returns the first entry of "output" from the first entry of "inference_results" in a
+     * predict response.
+     *
+     * @param predictResponse response of a predict call
+     * @return the first output entry as a map
+     */
+    private Map firstPredictOutput(Response predictResponse) throws IOException {
+        Map responseMap = parseResponseToMap(predictResponse);
+        List inferenceResults = (List) responseMap.get("inference_results");
+        Map firstResult = (Map) inferenceResults.get(0);
+        List outputs = (List) firstResult.get("output");
+        return (Map) outputs.get(0);
+    }
+
     private Response createConnector(String input) throws IOException {
         return TestHelper
             .makeRequest(