Skip to content

Commit

Permalink
add new data fields in the memory layer and update tests
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Dec 2, 2023
1 parent 6b23dc9 commit 9f9a0ce
Show file tree
Hide file tree
Showing 35 changed files with 1,043 additions and 351 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ public class ActionConstants {
public final static String PROMPT_TEMPLATE_FIELD = "prompt_template";
/** name of metadata field in all requests */
public final static String ADDITIONAL_INFO_FIELD = "additional_info";
/** name of metadata field in all requests */
public final static String PARENT_INTERACTION_ID_FIELD = "parent_interaction_id";
/** name of metadata field in all requests */
public final static String TRACE_NUMBER_FIELD = "trace_number";
/** name of success field in all requests */
public final static String SUCCESS_FIELD = "success";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public class ConversationMeta implements Writeable, ToXContentObject {
@Getter
private Instant createdTime;
@Getter
private Instant updatedTime;
@Getter
private String name;
@Getter
private String user;
Expand All @@ -66,9 +68,10 @@ public static ConversationMeta fromSearchHit(SearchHit hit) {
*/
public static ConversationMeta fromMap(String id, Map<String, Object> docFields) {
Instant created = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_CREATED_FIELD));
Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_FIELD));
String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD);
String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD);
return new ConversationMeta(id, created, name, user);
return new ConversationMeta(id, created, updated, name, user);
}

/**
Expand All @@ -81,15 +84,17 @@ public static ConversationMeta fromMap(String id, Map<String, Object> docFields)
public static ConversationMeta fromStream(StreamInput in) throws IOException {
String id = in.readString();
Instant created = in.readInstant();
Instant updated = in.readInstant();
String name = in.readString();
String user = in.readOptionalString();
return new ConversationMeta(id, created, name, user);
return new ConversationMeta(id, created, updated, name, user);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(id);
out.writeInstant(createdTime);
out.writeInstant(updatedTime);
out.writeString(name);
out.writeOptionalString(user);
}
Expand All @@ -104,6 +109,7 @@ public IndexRequest toIndexRequest(String index) {
IndexRequest request = new IndexRequest(index);
return request.id(this.id).source(
ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime,
ConversationalIndexConstants.META_UPDATED_FIELD, this.updatedTime,
ConversationalIndexConstants.META_NAME_FIELD, this.name
);
}
Expand All @@ -113,6 +119,7 @@ public String toString() {
return "{id=" + id
+ ", name=" + name
+ ", created=" + createdTime.toString()
+ ", updated=" + updatedTime.toString()
+ ", user=" + user
+ "}";
}
Expand All @@ -122,6 +129,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
builder.startObject();
builder.field(ActionConstants.CONVERSATION_ID_FIELD, this.id);
builder.field(ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime);
builder.field(ConversationalIndexConstants.META_UPDATED_FIELD, this.updatedTime);
builder.field(ConversationalIndexConstants.META_NAME_FIELD, this.name);
if(this.user != null) {
builder.field(ConversationalIndexConstants.USER_FIELD, this.user);
Expand All @@ -137,9 +145,10 @@ public boolean equals(Object other) {
}
ConversationMeta otherConversation = (ConversationMeta) other;
return Objects.equals(this.id, otherConversation.id) &&
Objects.equals(this.user, otherConversation.user) &&
Objects.equals(this.createdTime, otherConversation.createdTime) &&
Objects.equals(this.name, otherConversation.name);
Objects.equals(this.user, otherConversation.user) &&
Objects.equals(this.createdTime, otherConversation.createdTime) &&
Objects.equals(this.updatedTime, otherConversation.updatedTime) &&
Objects.equals(this.name, otherConversation.name);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ public class ConversationalIndexConstants {
public final static String META_INDEX_NAME = ".plugins-ml-conversation-meta";
/** Name of the metadata field for initial timestamp */
public final static String META_CREATED_FIELD = "create_time";
/** Name of the metadata field for updated timestamp */
public final static String META_UPDATED_FIELD = "updated_time";
/** Name of the metadata field for name of the conversation */
public final static String META_NAME_FIELD = "name";
/** Name of the owning user field in all indices */
public final static String USER_FIELD = "user";
/** Name of the application that created this conversation */
public final static String APPLICATION_TYPE_FIELD = "application_type";
/** Mappings for the conversational metadata index */
public final static String META_MAPPING = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -41,12 +45,18 @@ public class ConversationalIndexConstants {
+ " \"properties\": {\n"
+ " \""
+ META_NAME_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ "\": {\"type\": \"text\"},\n"
+ " \""
+ META_CREATED_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ META_UPDATED_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ USER_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ APPLICATION_TYPE_FIELD
+ "\": {\"type\": \"keyword\"}\n"
+ " }\n"
+ "}";
Expand All @@ -69,6 +79,10 @@ public class ConversationalIndexConstants {
public final static String INTERACTIONS_ADDITIONAL_INFO_FIELD = "additional_info";
/** Name of the interaction field for the timestamp */
public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time";
/** Name of the interaction id */
public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_interaction_id";
/** The trace number of an interaction */
public final static String INTERACTIONS_TRACE_NUMBER_FIELD = "trace_number";
/** Mappings for the interactions index */
public final static String INTERACTIONS_MAPPINGS = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -95,7 +109,13 @@ public class ConversationalIndexConstants {
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_ADDITIONAL_INFO_FIELD
+ "\": {\"type\": \"text\"}\n"
+ "\": {\"type\": \"flat_object\"},\n"
+ " \""
+ PARENT_INTERACTIONS_ID_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_TRACE_NUMBER_FIELD
+ "\": {\"type\": \"long\"}\n"
+ " }\n"
+ "}";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -54,7 +55,24 @@ public class Interaction implements Writeable, ToXContentObject {
@Getter
private String origin;
@Getter
private String additionalInfo;
private Map<String, String> additionalInfo;
@Getter
private String parentInteractionId;
@Getter
private Integer traceNum;

public Interaction(String id, Instant createTime, String conversationId, String input, String promptTemplate, String response, String origin, Map<String, String> additionalInfo) {
this.id = id;
this.createTime = createTime;
this.conversationId = conversationId;
this.input = input;
this.promptTemplate = promptTemplate;
this.response = response;
this.origin = origin;
this.additionalInfo = additionalInfo;
this.parentInteractionId = null;
this.traceNum = null;
}

/**
* Creates an Interaction object from a map of fields in the OS index
Expand All @@ -69,7 +87,9 @@ public static Interaction fromMap(String id, Map<String, Object> fields) {
String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD);
String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD);
String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD);
String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
Map<String,String> additionalInfo = (Map<String,String>) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
String parentInteractionId = (String) fields.getOrDefault(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, null);
Integer traceNum = (Integer) fields.getOrDefault(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, null);
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
}

Expand Down Expand Up @@ -97,7 +117,12 @@ public static Interaction fromStream(StreamInput in) throws IOException {
String promptTemplate = in.readString();
String response = in.readString();
String origin = in.readString();
String additionalInfo = in.readOptionalString();
Map<String, String> additionalInfo = new HashMap<>();
if (in.readBoolean()) {
additionalInfo = in.readMap(s -> s.readString(), s -> s.readString());
}
String parentInteractionId = in.readOptionalString();
Integer traceNum = in.readOptionalInt();
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
}

Expand All @@ -111,7 +136,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(promptTemplate);
out.writeString(response);
out.writeString(origin);
out.writeOptionalString(additionalInfo);
if (additionalInfo != null) {
out.writeBoolean(true);
out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(parentInteractionId);
out.writeOptionalInt(traceNum);
}

@Override
Expand All @@ -127,6 +159,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
if(additionalInfo != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo);
}
if (parentInteractionId != null) {
builder.field(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId);
}
if (traceNum != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNum);
}
builder.endObject();
return builder;
}
Expand All @@ -143,7 +181,12 @@ public boolean equals(Object other) {
((Interaction) other).response.equals(this.response) &&
((Interaction) other).origin.equals(this.origin) &&
( (((Interaction) other).additionalInfo == null && this.additionalInfo == null) ||
((Interaction) other).additionalInfo.equals(this.additionalInfo))
((Interaction) other).additionalInfo.equals(this.additionalInfo)) &&
( (((Interaction) other).parentInteractionId == null && this.parentInteractionId == null) ||
((Interaction) other).parentInteractionId.equals(this.parentInteractionId)) &&
( (((Interaction) other).traceNum == null && this.traceNum == null) ||
((Interaction) other).traceNum.equals(this.traceNum))

);
}

Expand All @@ -158,8 +201,9 @@ public String toString() {
+ ",promt_template=" + promptTemplate
+ ",response=" + response
+ ",additional_info=" + additionalInfo
+ ",parentInteractionId=" + parentInteractionId
+ ",traceNum=" + traceNum
+ "}";
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
package org.opensearch.ml.memory;

import java.util.List;
import java.util.Map;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
Expand Down Expand Up @@ -58,6 +60,14 @@ public interface ConversationalMemoryHandler {
*/
public ActionFuture<String> createConversation(String name);

/**
* Create a new conversation
* @param name the name of the new conversation
* @param applicationType the application that creates this conversation
* @param listener listener to wait for this op to finish, gets unique id of new conversation
*/
public void createConversation(String name, String applicationType, ActionListener<String> listener);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
Expand All @@ -74,7 +84,7 @@ public void createInteraction(
String promptTemplate,
String response,
String origin,
String additionalInfo,
Map<String, String> additionalInfo,
ActionListener<String> listener
);

Expand All @@ -94,7 +104,31 @@ public ActionFuture<String> createInteraction(
String promptTemplate,
String response,
String origin,
String additionalInfo
Map<String, String> additionalInfo
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
* @param input the human input for the interaction
* @param promptTemplate the prompt template used for this interaction
* @param response the Gen AI response for this interaction
* @param origin the name of the GenAI agent in this interaction
* @param additionalInfo additional information used in constructing the LLM prompt
* @param interactionId the parent interactionId of this interaction
* @param traceNumber the trace number for a parent interaction
* @param listener gets the ID of the new interaction
*/
public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
Map<String, String> additionalInfo,
ActionListener<String> listener,
String interactionId,
Integer traceNumber
);

/**
Expand All @@ -120,6 +154,15 @@ public ActionFuture<String> createInteraction(
*/
public void getInteractions(String conversationId, int from, int maxResults, ActionListener<List<Interaction>> listener);

/**
* Get the traces associate with this interaction, sorted by recency
* @param interactionId the interaction whose traces to get
* @param from where to start listing from
* @param maxResults how many traces to get
* @param listener gets the list of traces in this conversation, sorted by recency
*/
public void getTraces(String interactionId, int from, int maxResults, ActionListener<List<Interaction>> listener);

/**
* Get the interactions associate with this conversation, sorted by recency
* @param conversationId the conversation whose interactions to get
Expand Down Expand Up @@ -203,6 +246,13 @@ public ActionFuture<String> createInteraction(
*/
public ActionFuture<SearchResponse> searchInteractions(String conversationId, SearchRequest request);

/**
* Update a conversation
* @param updateContent update content for the conversations index
* @param listener receives the update response
*/
public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);

/**
* Get a single ConversationMeta object
* @param conversationId id of the conversation to get
Expand Down
Loading

0 comments on commit 9f9a0ce

Please sign in to comment.