Skip to content

Commit

Permalink
Add integration tests for the RAG pipeline covering OpenAI and Bedrock (
Browse files Browse the repository at this point in the history
opensearch-project#2213)

* Add integration tests for the RAG pipeline covering OpenAI and Bedrock.

Signed-off-by: Austin Lee <[email protected]>

* Use us-west-2 for testing.

Signed-off-by: Austin Lee <[email protected]>

* Fix spotless

Signed-off-by: Austin Lee <[email protected]>

* Fix broken import.

Signed-off-by: Austin Lee <[email protected]>

* Fix loading of default feature flag settings, update tests to pick up
changes in memory API.

Signed-off-by: Austin Lee <[email protected]>

* Remove unused code.

Signed-off-by: Austin Lee <[email protected]>

* Remove unused code.

Signed-off-by: Austin Lee <[email protected]>

---------

Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee authored Mar 19, 2024
1 parent 951dbcf commit d645c83
Show file tree
Hide file tree
Showing 9 changed files with 601 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc
private ScriptService scriptService;
private Encryptor encryptor;

public MachineLearningPlugin(Settings settings) {
// Handle this here as this feature is tied to Search/Query API, not to a ml-common API
// and as such, it can't be lazy-loaded when a ml-commons API is invoked.
this.ragSearchPipelineEnabled = MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
}

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return ImmutableList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.spi.MLCommonsExtension;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.engine.tools.MLModelTool;
Expand All @@ -47,7 +48,7 @@

public class MachineLearningPluginTests {

MachineLearningPlugin plugin = new MachineLearningPlugin();
MachineLearningPlugin plugin = new MachineLearningPlugin(Settings.EMPTY);

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ protected Response ingestIrisData(String indexName) throws IOException, ParseExc
TestHelper.toHttpEntity(TestData.IRIS_DATA.replaceAll("iris_data", indexName)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
);

statsResponse = TestHelper.makeRequest(client(), "GET", indexName, ImmutableMap.of(), "", null);
assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse));

assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse));
return bulkResponse;
}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.function.Consumer;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
Expand All @@ -24,10 +25,10 @@

public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {

private final String OPENAI_KEY = System.getenv("OPENAI_KEY");
private final String COHERE_KEY = System.getenv("COHERE_KEY");
final String OPENAI_KEY = System.getenv("OPENAI_KEY");
final String COHERE_KEY = System.getenv("COHERE_KEY");

private final String completionModelConnectorEntity = "{\n"
final String completionModelConnectorEntity = "{\n"
+ "\"name\": \"OpenAI Connector\",\n"
+ "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
+ "\"version\": 1,\n"
Expand Down Expand Up @@ -69,6 +70,8 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {
@Before
public void setup() throws IOException, InterruptedException {
disableClusterConnectorAccessControl();
// TODO Do we really need to wait this long? This adds 20s to every test case run.
// Can we instead check the cluster state and move on?
Thread.sleep(20000);
}

Expand Down Expand Up @@ -735,11 +738,11 @@ public void testCohereClassifyModel() throws IOException, InterruptedException {
assertFalse(responseList.isEmpty());
}

private Response createConnector(String input) throws IOException {
protected Response createConnector(String input) throws IOException {
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, TestHelper.toHttpEntity(input), null);
}

private Response registerRemoteModel(String name, String connectorId) throws IOException {
protected Response registerRemoteModel(String name, String connectorId) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
+ " \"description\": \"This is an example description\"\n"
Expand Down Expand Up @@ -775,15 +778,15 @@ private Response registerRemoteModel(String name, String connectorId) throws IOE
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

private Response deployRemoteModel(String modelId) throws IOException {
protected Response deployRemoteModel(String modelId) throws IOException {
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, "", null);
}

private Response predictRemoteModel(String modelId, String input) throws IOException {
protected Response predictRemoteModel(String modelId, String input) throws IOException {
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, input, null);
}

private Response undeployRemoteModel(String modelId) throws IOException {
protected Response undeployRemoteModel(String modelId) throws IOException {
String undeployEntity = "{\n"
+ " \"SYqCMdsFTumUwoHZcsgiUg\": {\n"
+ " \"stats\": {\n"
Expand All @@ -796,13 +799,20 @@ private Response undeployRemoteModel(String modelId) throws IOException {
return TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, undeployEntity, null);
}

private boolean checkThrottlingOpenAI(Map responseMap) {
protected boolean checkThrottlingOpenAI(Map responseMap) {
Map map = (Map) responseMap.get("error");
String message = (String) map.get("message");
return message.equals("You exceeded your current quota, please check your plan and billing details.");
}

private void disableClusterConnectorAccessControl() throws IOException {
protected Map parseResponseToMap(Response response) throws IOException {
HttpEntity entity = response.getEntity();
assertNotNull(response);
String entityString = TestHelper.httpEntityToString(entity);
return gson.fromJson(entityString, Map.class);
}

protected void disableClusterConnectorAccessControl() throws IOException {
Response response = TestHelper
.makeRequest(
client(),
Expand All @@ -815,7 +825,7 @@ private void disableClusterConnectorAccessControl() throws IOException {
assertEquals(200, response.getStatusLine().getStatusCode());
}

private Response getTask(String taskId) throws IOException {
protected Response getTask(String taskId) throws IOException {
return TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"title" : "Abraham Lincoln 1", "text" : "Abraham Lincoln (/ˈlɪŋkən/ LINK-ən; February 12, 1809 – April 15, 1865) was an American lawyer, politician, and statesman who served as the 16th president of the United States from 1861 until his assassination in 1865. Lincoln led the Union through the American Civil War to defend the nation as a constitutional union and succeeded in abolishing slavery, bolstering the federal government, and modernizing the U.S. economy.\n"}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"title" : "Abraham Lincoln 2", "text" : "Abraham Lincoln was born on February 12, 1809, the second child of Thomas Lincoln and Nancy Hanks Lincoln, in a log cabin on Sinking Spring Farm near Hodgenville, Kentucky.[2] He was a descendant of Samuel Lincoln, an Englishman who migrated from Hingham, Norfolk, to its namesake, Hingham, Massachusetts, in 1638. The family then migrated west, passing through New Jersey, Pennsylvania, and Virginia.[3] Lincoln was also a descendant of the Harrison family of Virginia; his paternal grandfather and namesake, Captain Abraham Lincoln and wife Bathsheba (née Herring) moved the family from Virginia to Jefferson County, Kentucky.[b] The captain was killed in an Indian raid in 1786.[5] His children, including eight-year-old Thomas, Abraham's father, witnessed the attack.[6][c] Thomas then worked at odd jobs in Kentucky and Tennessee before the family settled in Hardin County, Kentucky, in the early 1800s.[6]\n"}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"title" : "Abraham Lincoln 3", "text" : "Lincoln was born into poverty in a log cabin in Kentucky and was raised on the frontier, primarily in Indiana. He was self-educated and became a lawyer, Whig Party leader, Illinois state legislator, and U.S. Congressman from Illinois. In 1849, he returned to his successful law practice in central Illinois. In 1854, he was angered by the Kansas–Nebraska Act, which opened the territories to slavery, and he re-entered politics. He soon became a leader of the new Republican Party. He reached a national audience in the 1858 Senate campaign debates against Stephen A. Douglas. Lincoln ran for president in 1860, sweeping the North to gain victory. Pro-slavery elements in the South viewed his election as a threat to slavery, and Southern states began seceding from the nation. During this time, the newly formed Confederate States of America began seizing federal military bases in the south. Just over one month after Lincoln assumed the presidency, the Confederate States attacked Fort Sumter, a U.S. fort in South Carolina. Following the bombardment, Lincoln mobilized forces to suppress the rebellion and restore the union.\n"}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class GenerativeSearchResponse extends SearchResponse {
private static final String EXT_SECTION_NAME = "ext";
private static final String GENERATIVE_QA_ANSWER_FIELD_NAME = "answer";
private static final String GENERATIVE_QA_ERROR_FIELD_NAME = "error";
private static final String INTERACTION_ID_FIELD_NAME = "interaction_id";
private static final String INTERACTION_ID_FIELD_NAME = "message_id";

private final String answer;
private String errorMessage;
Expand Down

0 comments on commit d645c83

Please sign in to comment.