diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java index ba4f1c9b03..d5ec0e47c1 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -17,25 +17,25 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Objects; -import org.opensearch.core.ParseField; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import com.google.common.base.Preconditions; +import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; @@ -48,60 +48,44 @@ @NoArgsConstructor public class GenerativeQAParameters implements Writeable, ToXContentObject { - private static final ObjectParser PARSER; - // Optional parameter; if provided, conversational memory will be used for RAG // and the current interaction will be saved in the conversation referenced by this id. - private static final ParseField CONVERSATION_ID = new ParseField("memory_id"); + private static final String CONVERSATION_ID = "memory_id"; // Optional parameter; if an LLM model is not set at the search pipeline level, one must be // provided at the search request level. - private static final ParseField LLM_MODEL = new ParseField("llm_model"); + private static final String LLM_MODEL = "llm_model"; // Required parameter; this is sent to LLMs as part of the user prompt. // TODO support question rewriting when chat history is not used (conversation_id is not provided). - private static final ParseField LLM_QUESTION = new ParseField("llm_question"); + private static final String LLM_QUESTION = "llm_question"; // Optional parameter; this parameter controls the number of search results ("contexts") to // include in the user prompt. - private static final ParseField CONTEXT_SIZE = new ParseField("context_size"); + private static final String CONTEXT_SIZE = "context_size"; // Optional parameter; this parameter controls the number of the interactions to include // in the user prompt. - private static final ParseField INTERACTION_SIZE = new ParseField("message_size"); + private static final String INTERACTION_SIZE = "message_size"; // Optional parameter; this parameter controls how long the search pipeline waits for a response // from a remote inference endpoint before timing out the request. - private static final ParseField TIMEOUT = new ParseField("timeout"); + private static final String TIMEOUT = "timeout"; // Optional parameter: this parameter allows request-level customization of the "system" (role) prompt. - private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT); + private static final String SYSTEM_PROMPT = "system_prompt"; // Optional parameter: this parameter allows request-level customization of the "user" (role) prompt. - private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS); + private static final String USER_INSTRUCTIONS = "user_instructions"; // Optional parameter; this parameter indicates the name of the field in the LLM response // that contains the chat completion text, i.e. "answer". - private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field"); + private static final String LLM_RESPONSE_FIELD = "llm_response_field"; - private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages"); + private static final String LLM_MESSAGES_FIELD = "llm_messages"; public static final int SIZE_NULL_VALUE = -1; - static { - PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new); - PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID); - PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL); - PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION); - PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT); - PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS); - PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE); - PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE); - PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT); - PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD); - PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD); - } - @Setter @Getter private String conversationId; @@ -167,6 +151,7 @@ public GenerativeQAParameters( ); } + @Builder(toBuilder = true) public GenerativeQAParameters( String conversationId, String llmModel, @@ -184,7 +169,7 @@ public GenerativeQAParameters( // TODO: keep this requirement until we can extract the question from the query or from the request processor parameters // for question rewriting. - Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided."); + Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided."); this.llmQuestion = llmQuestion; this.systemPrompt = systemPrompt; this.userInstructions = userInstructions; @@ -212,17 +197,49 @@ public GenerativeQAParameters(StreamInput input) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - return xContentBuilder - .field(CONVERSATION_ID.getPreferredName(), this.conversationId) - .field(LLM_MODEL.getPreferredName(), this.llmModel) - .field(LLM_QUESTION.getPreferredName(), this.llmQuestion) - .field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt) - .field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions) - .field(CONTEXT_SIZE.getPreferredName(), this.contextSize) - .field(INTERACTION_SIZE.getPreferredName(), this.interactionSize) - .field(TIMEOUT.getPreferredName(), this.timeout) - .field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField) - .field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages); + xContentBuilder.startObject(); + if (this.conversationId != null) { + xContentBuilder.field(CONVERSATION_ID, this.conversationId); + } + + if (this.llmModel != null) { + xContentBuilder.field(LLM_MODEL, this.llmModel); + } + + if (this.llmQuestion != null) { + xContentBuilder.field(LLM_QUESTION, this.llmQuestion); + } + + if (this.systemPrompt != null) { + xContentBuilder.field(SYSTEM_PROMPT, this.systemPrompt); + } + + if (this.userInstructions != null) { + xContentBuilder.field(USER_INSTRUCTIONS, this.userInstructions); + } + + if (this.contextSize != null) { + xContentBuilder.field(CONTEXT_SIZE, this.contextSize); + } + + if (this.interactionSize != null) { + xContentBuilder.field(INTERACTION_SIZE, this.interactionSize); + } + + if (this.timeout != null) { + xContentBuilder.field(TIMEOUT, this.timeout); + } + + if (this.llmResponseField != null) { + xContentBuilder.field(LLM_RESPONSE_FIELD, this.llmResponseField); + } + + if (this.llmMessages != null && !this.llmMessages.isEmpty()) { + xContentBuilder.field(LLM_MESSAGES_FIELD, this.llmMessages); + } + + xContentBuilder.endObject(); + return xContentBuilder; } @Override @@ -242,7 +259,76 @@ public void writeTo(StreamOutput out) throws IOException { } public static GenerativeQAParameters parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); + String conversationId = null; + String llmModel = null; + String llmQuestion = null; + String systemPrompt = null; + String userInstructions = null; + Integer contextSize = null; + Integer interactionSize = null; + Integer timeout = null; + String llmResponseField = null; + List llmMessages = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String field = parser.currentName(); + parser.nextToken(); + + switch (field) { + case CONVERSATION_ID: + conversationId = parser.text(); + break; + case LLM_MODEL: + llmModel = parser.text(); + break; + case LLM_QUESTION: + llmQuestion = parser.text(); + break; + case SYSTEM_PROMPT: + systemPrompt = parser.text(); + break; + case USER_INSTRUCTIONS: + userInstructions = parser.text(); + break; + case CONTEXT_SIZE: + contextSize = parser.intValue(); + break; + case INTERACTION_SIZE: + interactionSize = parser.intValue(); + break; + case TIMEOUT: + timeout = parser.intValue(); + break; + case LLM_RESPONSE_FIELD: + llmResponseField = parser.text(); + break; + case LLM_MESSAGES_FIELD: + llmMessages = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + llmMessages.add(MessageBlock.fromXContent(parser)); + } + break; + default: + parser.skipChildren(); + break; + } + } + + return GenerativeQAParameters + .builder() + .conversationId(conversationId) + .llmModel(llmModel) + .llmQuestion(llmQuestion) + .systemPrompt(systemPrompt) + .userInstructions(userInstructions) + .contextSize(contextSize) + .interactionSize(interactionSize) + .timeout(timeout) + .llmResponseField(llmResponseField) + .llmMessages(llmMessages) + .build(); } @Override diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index 23eb6f3d3a..2772884f11 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -21,20 +21,25 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.EOFException; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Map; +import org.junit.Assert; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentHelper; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock; import org.opensearch.test.OpenSearchTestCase; @@ -121,21 +126,38 @@ public void testMiscMethods() throws IOException { } public void testParse() throws IOException { - XContentParser xcParser = mock(XContentParser.class); - when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT); - GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser); + String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + requiredJsonStr + ); + + parser.nextToken(); + GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(parser); assertNotNull(builder); assertNotNull(builder.getParams()); + GenerativeQAParameters params = builder.getParams(); + Assert.assertEquals("this is test llm question", params.getLlmQuestion()); } public void testXContentRoundTrip() throws IOException { GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null, messageList); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); - BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()); + builder = extBuilder.toXContent(builder, EMPTY_PARAMS); + BytesReference serialized = BytesReference.bytes(builder); + XContentParser parser = createParser(xContentType.xContent(), serialized); + parser.nextToken(); GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); GenerativeQAParameters parameters = deserialized.getParams(); assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize()); @@ -147,10 +169,16 @@ public void testXContentRoundTripAllValues() throws IOException { GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3, null); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); - BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentBuilder builder = XContentBuilder.builder(xContentType.xContent()); + builder = extBuilder.toXContent(builder, EMPTY_PARAMS); + BytesReference serialized = BytesReference.bytes(builder); + XContentParser parser = createParser(xContentType.xContent(), serialized); + parser.nextToken(); GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index e5caa70ed7..4835b764fe 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -239,7 +239,18 @@ public void testToXConent() throws IOException { assertNotNull(parameters.toXContent(builder, null)); } - public void testToXConentAllOptionalParameters() throws IOException { + public void testToXContentEmptyParams() throws IOException { + GenerativeQAParameters parameters = new GenerativeQAParameters(); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + parameters.toXContent(builder, null); + assertNotNull(parameters.toXContent(builder, null)); + } + + public void testToXContentAllOptionalParameters() throws IOException { String conversationId = "a"; String llmModel = "b"; String llmQuestion = "c";