Skip to content

Commit

Permalink
Fix internal connector (opensearch-project#1989)
Browse files Browse the repository at this point in the history
* Fix internal connector

Signed-off-by: Sicheng Song <[email protected]>

* spotless fix

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo authored Feb 2, 2024
1 parent 0755e50 commit 343ae16
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,44 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID_FIELD, modelId);
if (name != null) {
builder.field(MODEL_NAME_FIELD, name);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
if (version != null) {
builder.field(MODEL_VERSION_FIELD, version);
}
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (isEnabled != null) {
builder.field(IS_ENABLED_FIELD, isEnabled);
}
if (rateLimiter != null) {
builder.field(RATE_LIMITER_FIELD, rateLimiter);
}
if (modelConfig != null) {
builder.field(MODEL_CONFIG_FIELD, modelConfig);
}
// Notice that we serialize the updatedConnector to the connector field, in order to be compatible with original internal connector field format.
if (updatedConnector != null) {
builder.field(CONNECTOR_FIELD, updatedConnector);
}
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
if (lastUpdateTime != null) {
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ public class MLUpdateModelInputTest {
+
"\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1}";

private final String expectedOutputStrForUpdateRequestDoc = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":"
+
"\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" +
"{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" +
"{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\""
+
"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" +
"{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":"
+
"{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":"
+
"\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":"
+
"\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":"
+
"\"test-connector_id\",\"last_updated_time\":1}";

private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":"
+
"\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" +
Expand Down Expand Up @@ -153,6 +170,21 @@ public void testToXContent() throws Exception {
assertEquals(expectedInputStr, jsonStr);
}

@Test
public void testToXContentForUpdateRequestDoc() throws Exception {
String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput);
assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr);
}

@Test
public void testToXContenttForUpdateRequestDocIncomplete() throws Exception {
String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}";
updateModelInput = MLUpdateModelInput.builder()
.modelId("test-model_id").build();
String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput);
assertEquals(expectedIncompleteInputStr, jsonStr);
}

@Test
public void testToXContentIncomplete() throws Exception {
String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}";
Expand Down Expand Up @@ -237,4 +269,11 @@ private String serializationWithToXContent(MLUpdateModelInput input) throws IOEx
assertNotNull(builder);
return builder.toString();
}

private String serializationWithToXContentForUpdateRequestDoc(MLUpdateModelInput input) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder();
input.toXContentForUpdateRequestDoc(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
return builder.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ private void updateRemoteOrTextEmbeddingModel(
MLModel mlModel,
User user,
ActionListener<UpdateResponse> wrappedListener
) {
) throws IOException {
String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId())
&& !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null;
String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null;
Expand Down Expand Up @@ -330,7 +330,7 @@ private void updateModelWithRegisteringToAnotherModelGroup(
.validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> {
if (hasNewModelGroupPermission) {
mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> {
updateRequestConstructor(
buildUpdateRequest(
modelId,
newModelGroupId,
updateRequest,
Expand Down Expand Up @@ -364,11 +364,11 @@ private void updateModelWithRegisteringToAnotherModelGroup(
wrappedListener.onFailure(exception);
}));
} else {
updateRequestConstructor(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache);
buildUpdateRequest(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache);
}
}

private void updateRequestConstructor(
private void buildUpdateRequest(
String modelId,
UpdateRequest updateRequest,
MLUpdateModelInput updateModelInput,
Expand All @@ -377,7 +377,7 @@ private void updateRequestConstructor(
) {
try {
updateModelInput.setLastUpdateTime(Instant.now());
updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (isUpdateModelCache) {
Expand All @@ -397,7 +397,7 @@ private void updateRequestConstructor(
}
}

private void updateRequestConstructor(
private void buildUpdateRequest(
String modelId,
String newModelGroupId,
UpdateRequest updateRequest,
Expand All @@ -418,7 +418,7 @@ private void updateRequestConstructor(
Integer.parseInt(updatedVersion)
);
try {
updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (isUpdateModelCache) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,9 @@ public void testUpdateRequestDocIOException() throws IOException {
doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm();
doReturn(MLModelState.REGISTERED).when(mockModel).getModelState();

doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any());
doThrow(new IOException("Exception occurred during building update request."))
.when(mockUpdateModelInput)
.toXContentForUpdateRequestDoc(any(), any());
transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IOException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
Expand Down Expand Up @@ -700,7 +702,9 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO
return null;
}).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class));

doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any());
doThrow(new IOException("Exception occurred during building update request."))
.when(mockUpdateModelInput)
.toXContentForUpdateRequestDoc(any(), any());
transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IOException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
Expand Down

0 comments on commit 343ae16

Please sign in to comment.