Skip to content

Commit

Permalink
add singular get rest actions
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
HenryL27 authored and Zhangxunmt committed Nov 30, 2023
1 parent fa8c697 commit 86d0183
Show file tree
Hide file tree
Showing 12 changed files with 363 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ public class ActionConstants {
private final static String BASE_REST_INTERACTION_PATH = "/_plugins/_ml/memory/interaction";
/** path for create conversation */
public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create";
/** path for list conversations */
/** path for get conversations */
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list";
/** path for update conversations */
public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_update";
/** path for put interaction */
/** path for create interaction */
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
/** path for get interactions */
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list";
Expand All @@ -79,6 +80,10 @@ public class ActionConstants {
public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search";
/** path for update interactions */
public final static String UPDATE_INTERACTIONS_REST_PATH = BASE_REST_INTERACTION_PATH + "/{interaction_id}/_update";
/** path for get conversation */
public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}";
/** path for get interaction */
public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}";

/** default max results returned by get operations */
public final static int DEFAULT_MAX_RESULTS = 10;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class GetInteractionsAction extends ActionType<GetInteractionsResponse> {
/** Instance of this */
public static final GetInteractionsAction INSTANCE = new GetInteractionsAction();
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get";
public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/list";

private GetInteractionsAction() {
super(NAME, GetInteractionsResponse::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
public class GetConversationResponseTests extends OpenSearchTestCase {

public void testGetConversationResponseStreaming() throws IOException {
ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null);
ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null);
GetConversationResponse response = new GetConversationResponse(convo);
assert (response.getConversation().equals(convo));

Expand All @@ -49,12 +49,16 @@ public void testGetConversationResponseStreaming() throws IOException {
}

public void testToXContent() throws IOException {
ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null);
ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null);
GetConversationResponse response = new GetConversationResponse(convo);
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
String expected = "{\"conversation_id\":\"cid\",\"create_time\":\"" + convo.getCreatedTime() + "\",\"name\":\"name\"}";
String expected = "{\"conversation_id\":\"cid\",\"create_time\":\""
+ convo.getCreatedTime()
+ "\"updated_time\":\""
+ convo.getUpdatedTime()
+ "\",\"name\":\"name\"}";
// Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness
LevenshteinDistance ld = new LevenshteinDistance();
assert (ld.getDistance(result, expected) > 0.95);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public void setup() throws IOException {
}

public void testGetConversation() {
ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), "name", null);
ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), Instant.now(), "name", null);
doAnswer(invocation -> {
ActionListener<ConversationMeta> listener = invocation.getArgument(1);
listener.onResponse(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.Collections;

import org.apache.lucene.search.spell.LevenshteinDistance;
import org.opensearch.common.io.stream.BytesStreamOutput;
Expand All @@ -36,7 +37,16 @@
public class GetInteractionResponseTests extends OpenSearchTestCase {

public void testConstructorAndStreaming() throws IOException {
Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra");
Interaction interaction = new Interaction(
"iid",
Instant.now(),
"cid",
"inp",
"pt",
"rsp",
"ogn",
Collections.singletonMap("meta", "some meta")
);
GetInteractionResponse response = new GetInteractionResponse(interaction);
assert (response.getInteraction().equals(interaction));

Expand All @@ -49,14 +59,23 @@ public void testConstructorAndStreaming() throws IOException {
}

public void testToXContent() throws IOException {
Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra");
Interaction interaction = new Interaction(
"iid",
Instant.now(),
"cid",
"inp",
"pt",
"rsp",
"ogn",
Collections.singletonMap("meta", "some meta")
);
GetInteractionResponse response = new GetInteractionResponse(interaction);
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
String expected = "{\"conversation_id\":\"cid\",\"interaction_id\":\"iid\",\"create_time\":\""
+ interaction.getCreateTime()
+ "\",\"input\":\"inp\",\"prompt_template\":\"pt\",\"response\":\"rsp\",\"origin\":\"ogn\",\"additional_info\":\"extra\"}";
+ "\",\"input\":\"inp\",\"prompt_template\":\"pt\",\"response\":\"rsp\",\"origin\":\"ogn\",\"additional_info\":{\"meta\":\"some meta\"}}";
// Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness
LevenshteinDistance ld = new LevenshteinDistance();
assert (ld.getDistance(result, expected) > 0.95);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import java.util.Set;

import org.junit.Before;
Expand Down Expand Up @@ -112,7 +113,7 @@ public void testGetInteraction() {
"pt",
"test-response",
"test-origin",
"metadata"
Collections.singletonMap("meta", "some meta")
);
doAnswer(invocation -> {
ActionListener<Interaction> listener = invocation.getArgument(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@
import org.opensearch.ml.memory.action.conversation.CreateInteractionTransportAction;
import org.opensearch.ml.memory.action.conversation.DeleteConversationAction;
import org.opensearch.ml.memory.action.conversation.DeleteConversationTransportAction;
import org.opensearch.ml.memory.action.conversation.GetConversationAction;
import org.opensearch.ml.memory.action.conversation.GetConversationTransportAction;
import org.opensearch.ml.memory.action.conversation.GetConversationsAction;
import org.opensearch.ml.memory.action.conversation.GetConversationsTransportAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionTransportAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionsAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction;
import org.opensearch.ml.memory.action.conversation.GetTracesAction;
Expand Down Expand Up @@ -217,7 +221,9 @@
import org.opensearch.ml.rest.RestMemoryCreateConversationAction;
import org.opensearch.ml.rest.RestMemoryCreateInteractionAction;
import org.opensearch.ml.rest.RestMemoryDeleteConversationAction;
import org.opensearch.ml.rest.RestMemoryGetConversationAction;
import org.opensearch.ml.rest.RestMemoryGetConversationsAction;
import org.opensearch.ml.rest.RestMemoryGetInteractionAction;
import org.opensearch.ml.rest.RestMemoryGetInteractionsAction;
import org.opensearch.ml.rest.RestMemoryGetTracesAction;
import org.opensearch.ml.rest.RestMemorySearchConversationsAction;
Expand Down Expand Up @@ -364,7 +370,9 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc
new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class),
new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class),
new ActionHandler<>(MLAgentGetAction.INSTANCE, GetAgentTransportAction.class),
new ActionHandler<>(MLAgentDeleteAction.INSTANCE, DeleteAgentTransportAction.class)
new ActionHandler<>(MLAgentDeleteAction.INSTANCE, DeleteAgentTransportAction.class),
new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class),
new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class)
);
}

Expand Down Expand Up @@ -664,6 +672,8 @@ public List<RestHandler> getRestHandlers(
RestMemoryGetTracesAction restMemoryGetTracesAction = new RestMemoryGetTracesAction();
RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction();
RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction();
RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction();
RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction();
return ImmutableList
.of(
restMLStatsAction,
Expand Down Expand Up @@ -707,7 +717,9 @@ public List<RestHandler> getRestHandlers(
restMemoryUpdateInteractionAction,
restMemoryGetTracesAction,
restMLGetAgentAction,
restMLDeleteAgentAction
restMLDeleteAgentAction,
restGetConversationAction,
restGetInteractionAction
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.rest;

import java.io.IOException;
import java.util.List;

import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.memory.action.conversation.GetConversationAction;
import org.opensearch.ml.memory.action.conversation.GetConversationRequest;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

public class RestMemoryGetConversationAction extends BaseRestHandler {
private final static String GET_CONVERSATION_NAME = "conversational_get_conversation";

@Override
public List<Route> routes() {
return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH));
}

@Override
public String getName() {
return GET_CONVERSATION_NAME;
}

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
GetConversationRequest gcRequest = GetConversationRequest.fromRestRequest(request);
return channel -> client.execute(GetConversationAction.INSTANCE, gcRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.rest;

import java.io.IOException;
import java.util.List;

import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

public class RestMemoryGetInteractionAction extends BaseRestHandler {
private final static String GET_INTERACTION_NAME = "conversational_get_interaction";

@Override
public List<Route> routes() {
return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTION_REST_PATH));
}

@Override
public String getName() {
return GET_INTERACTION_NAME;
}

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
GetInteractionRequest giRequest = GetInteractionRequest.fromRestRequest(request);
return channel -> client.execute(GetInteractionAction.INSTANCE, giRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.rest;

import java.io.IOException;
import java.util.Map;

import org.apache.http.HttpEntity;
import org.apache.http.HttpHeaders;
import org.apache.http.message.BasicHeader;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.utils.TestHelper;

import com.google.common.collect.ImmutableList;

public class RestMemoryGetConversationActionIT extends MLCommonsRestTestCase {
@Before
public void setupFeatureSettings() throws IOException {
Response response = TestHelper
.makeRequest(
client(),
"PUT",
"_cluster/settings",
null,
"{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}",
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
);
assertEquals(200, response.getStatusLine().getStatusCode());
}

public void testGetConversation() throws IOException {
Response ccresponse = TestHelper
.makeRequest(
client(),
"POST",
ActionConstants.CREATE_CONVERSATION_REST_PATH,
null,
gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")),
null
);
assert (ccresponse != null);
assert (TestHelper.restStatus(ccresponse) == RestStatus.OK);
HttpEntity cchttpEntity = ccresponse.getEntity();
String ccentityString = TestHelper.httpEntityToString(cchttpEntity);
@SuppressWarnings("unchecked")
Map<String, String> ccmap = gson.fromJson(ccentityString, Map.class);
assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD));
String id = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD);

Response gcresponse = TestHelper
.makeRequest(client(), "GET", ActionConstants.GET_CONVERSATION_REST_PATH.replace("{conversation_id}", id), null, "", null);
assert (gcresponse != null);
assert (TestHelper.restStatus(gcresponse) == RestStatus.OK);
HttpEntity gchttpEntity = gcresponse.getEntity();
String gcentitiyString = TestHelper.httpEntityToString(gchttpEntity);
@SuppressWarnings("unchecked")
Map<String, String> gcmap = gson.fromJson(gcentitiyString, Map.class);
assert (gcmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gcmap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(id));
assert (gcmap.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)
&& gcmap.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD).equals("name"));
}
}
Loading

0 comments on commit 86d0183

Please sign in to comment.