From 343ae16371f86039bda2b92673874804d84495ef Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Fri, 2 Feb 2024 11:23:26 -0800 Subject: [PATCH] Fix internal connector (#1989) * Fix internal connector Signed-off-by: Sicheng Song * spotless fix Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song --- .../transport/model/MLUpdateModelInput.java | 38 ++++++++++++++++++ .../model/MLUpdateModelInputTest.java | 39 +++++++++++++++++++ .../models/UpdateModelTransportAction.java | 14 +++---- .../UpdateModelTransportActionTests.java | 8 +++- 4 files changed, 90 insertions(+), 9 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index 74090c3491..065dff69e0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -140,6 +140,44 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID_FIELD, modelId); + if (name != null) { + builder.field(MODEL_NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (version != null) { + builder.field(MODEL_VERSION_FIELD, version); + } + if (modelGroupId != null) { + builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); + } + if (isEnabled != null) { + builder.field(IS_ENABLED_FIELD, isEnabled); + } + if (rateLimiter != null) { + builder.field(RATE_LIMITER_FIELD, rateLimiter); + } + if (modelConfig != null) { + builder.field(MODEL_CONFIG_FIELD, modelConfig); + } + // Notice that we serialize the updatedConnector to the connector field, in order to be compatible with original internal connector field format. + if (updatedConnector != null) { + builder.field(CONNECTOR_FIELD, updatedConnector); + } + if (connectorId != null) { + builder.field(CONNECTOR_ID_FIELD, connectorId); + } + if (lastUpdateTime != null) { + builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + builder.endObject(); + return builder; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index a53f1ee02d..6f4018165f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -61,6 +61,23 @@ public class MLUpdateModelInputTest { + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1}"; + private final String expectedOutputStrForUpdateRequestDoc = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + + "\"test-connector_id\",\"last_updated_time\":1}"; + private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":" + "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + @@ -153,6 +170,21 @@ public void testToXContent() throws Exception { assertEquals(expectedInputStr, jsonStr); } + @Test + public void testToXContentForUpdateRequestDoc() throws Exception { + String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput); + assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr); + } + + @Test + public void testToXContenttForUpdateRequestDocIncomplete() throws Exception { + String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; + updateModelInput = MLUpdateModelInput.builder() + .modelId("test-model_id").build(); + String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput); + assertEquals(expectedIncompleteInputStr, jsonStr); + } + @Test public void testToXContentIncomplete() throws Exception { String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; @@ -237,4 +269,11 @@ private String serializationWithToXContent(MLUpdateModelInput input) throws IOEx assertNotNull(builder); return builder.toString(); } + + private String serializationWithToXContentForUpdateRequestDoc(MLUpdateModelInput input) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContentForUpdateRequestDoc(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + return builder.toString(); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index d520977715..57e0361fae 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -200,7 +200,7 @@ private void updateRemoteOrTextEmbeddingModel( MLModel mlModel, User user, ActionListener wrappedListener - ) { + ) throws IOException { String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId()) && !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null; String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; @@ -330,7 +330,7 @@ private void updateModelWithRegisteringToAnotherModelGroup( .validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> { if (hasNewModelGroupPermission) { mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { - updateRequestConstructor( + buildUpdateRequest( modelId, newModelGroupId, updateRequest, @@ -364,11 +364,11 @@ private void updateModelWithRegisteringToAnotherModelGroup( wrappedListener.onFailure(exception); })); } else { - updateRequestConstructor(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache); + buildUpdateRequest(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache); } } - private void updateRequestConstructor( + private void buildUpdateRequest( String modelId, UpdateRequest updateRequest, MLUpdateModelInput updateModelInput, @@ -377,7 +377,7 @@ private void updateRequestConstructor( ) { try { updateModelInput.setLastUpdateTime(Instant.now()); - updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); if (isUpdateModelCache) { @@ -397,7 +397,7 @@ private void updateRequestConstructor( } } - private void updateRequestConstructor( + private void buildUpdateRequest( String modelId, String newModelGroupId, UpdateRequest updateRequest, @@ -418,7 +418,7 @@ private void updateRequestConstructor( Integer.parseInt(updatedVersion) ); try { - updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); if (isUpdateModelCache) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index c24d4bf816..b6ddfe8cb4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -651,7 +651,9 @@ public void testUpdateRequestDocIOException() throws IOException { doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); - doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); + doThrow(new IOException("Exception occurred during building update request.")) + .when(mockUpdateModelInput) + .toXContentForUpdateRequestDoc(any(), any()); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -700,7 +702,9 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO return null; }).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); - doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); + doThrow(new IOException("Exception occurred during building update request.")) + .when(mockUpdateModelInput) + .toXContentForUpdateRequestDoc(any(), any()); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); verify(actionListener).onFailure(argumentCaptor.capture());