Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integrated inference #181

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/integration/java/io/pinecone/clients/ConnectionsMapTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.pinecone.clients;

import io.pinecone.configs.PineconeConfig;
import io.pinecone.configs.PineconeConnection;
import io.pinecone.exceptions.PineconeNotFoundException;
import io.pinecone.helpers.RandomStringBuilder;
Expand Down Expand Up @@ -57,6 +58,10 @@ public void testMultipleIndexesWithMultipleClients() throws InterruptedException
// Get index1's host
String host1 = indexModel1.getHost();

// Create config1 for getting index connection and set the host
PineconeConfig config1 = new PineconeConfig(System.getenv("PINECONE_API_KEY"));
config1.setHost(host1);

// Create index-2
pinecone1.createServerlessIndex(indexName2,
null,
Expand All @@ -72,6 +77,11 @@ public void testMultipleIndexesWithMultipleClients() throws InterruptedException
// Get index2's host
String host2 = indexModel2.getHost();

// Create config2 for getting index connection and set the host.
// The host must be set on config2 itself: config1 already carries host1 and is
// used later for the index-1 connection, so it must not be overwritten here.
PineconeConfig config2 = new PineconeConfig(System.getenv("PINECONE_API_KEY"));
config2.setHost(host2);


// Establish grpc connection for index-1
Index index1_1 = pinecone1.getIndexConnection(indexName1);
// Get connections map
Expand All @@ -94,7 +104,7 @@ public void testMultipleIndexesWithMultipleClients() throws InterruptedException
assertEquals(host2, connectionsMap1_2.get(indexName2).toString());

// Establishing connections with index1 and index2 using another pinecone client
pinecone2.getConnection(indexName1);
pinecone2.getConnection(indexName1, config1);
ConcurrentHashMap<String, PineconeConnection> connectionsMap2_1 = pinecone1.getConnectionsMap();
// Verify the new connections map is pointing to the same reference
assert connectionsMap2_1 == connectionsMap1_2;
Expand All @@ -103,7 +113,7 @@ public void testMultipleIndexesWithMultipleClients() throws InterruptedException
// Verify the connection value for index1 is host1
assertEquals(host1, connectionsMap2_1.get(indexName1).toString());

pinecone2.getConnection(indexName2);
pinecone2.getConnection(indexName2, config2);
ConcurrentHashMap<String, PineconeConnection> connectionsMap2_2 = pinecone1.getConnectionsMap();
// Verify the new connections map is pointing to the same reference
assert connectionsMap2_1 == connectionsMap2_2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ public void createPodIndexWithDeletionProtectionDisabled() {
Map<String, String> actualTags = indexModel.getTags();
Assertions.assertEquals(expectedTags, actualTags);
// Configure index to enable deletionProtection
controlPlaneClient.configureServerlessIndex(indexName, DeletionProtection.ENABLED, expectedTags);
controlPlaneClient.configureServerlessIndex(indexName, DeletionProtection.ENABLED, expectedTags, null);
indexModel = controlPlaneClient.describeIndex(indexName);
deletionProtection = indexModel.getDeletionProtection();
Assertions.assertEquals(deletionProtection, DeletionProtection.ENABLED);
// Configure index to disable deletionProtection
controlPlaneClient.configureServerlessIndex(indexName, DeletionProtection.DISABLED, expectedTags);
controlPlaneClient.configureServerlessIndex(indexName, DeletionProtection.DISABLED, expectedTags, null);
// Delete index
controlPlaneClient.deleteIndex(indexName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public void configureSparseIndex() throws InterruptedException {
waitUntilIndexIsReady(pinecone, indexName, 200000);

// Disable deletion protection and add more index tags
pinecone.configureServerlessIndex(indexName, DeletionProtection.DISABLED, tags);
pinecone.configureServerlessIndex(indexName, DeletionProtection.DISABLED, tags, null);
Thread.sleep(7000);

// Describe index to confirm deletion protection is disabled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static void setUp() throws IOException, InterruptedException {
when(connectionMock.getBlockingStub()).thenReturn(stubMock);
when(connectionMock.getAsyncStub()).thenReturn(asyncStubMock);

index = new Index(connectionMock, "some-index-name");
index = new Index(config, connectionMock, "some-index-name");
asyncIndex = new AsyncIndex(config, connectionMock, "some-index-name");
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package io.pinecone.integration.dataPlane;

import io.pinecone.clients.Index;
import io.pinecone.clients.Pinecone;
import io.pinecone.helpers.RandomStringBuilder;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.openapitools.db_control.client.model.CreateIndexForModelRequest;
import org.openapitools.db_control.client.model.CreateIndexForModelRequestEmbed;
import org.openapitools.db_control.client.model.DeletionProtection;
import org.openapitools.db_data.client.ApiException;
import org.openapitools.db_data.client.model.SearchRecordsRequestQuery;
import org.openapitools.db_data.client.model.SearchRecordsRequestRerank;
import org.openapitools.db_data.client.model.SearchRecordsResponse;

import java.util.*;

/**
 * Integration test for the records (integrated inference) operations of {@link Index}:
 * creates an index backed by an embedding model, upserts four text records, and verifies
 * search by text inputs, search by record id, and search with reranking.
 * <p>
 * Requires the {@code PINECONE_API_KEY} environment variable and talks to the live service.
 */
public class UpsertAndSearchRecordsTest {
    /** Namespace used for both the upsert and every search in this test. */
    private static final String NAMESPACE = "example-namespace";

    /** Fixed pause after index creation before a connection is requested. */
    private static final long INDEX_CREATION_WAIT_MS = 10000;

    /** Fixed pause after the upsert before the records are searched. */
    private static final long UPSERT_WAIT_MS = 5000;

    /**
     * Builds one upsert record with the three keys this test relies on.
     *
     * @param id        value stored under {@code _id}
     * @param category  value stored under {@code category}
     * @param chunkText value stored under {@code chunk_text}, the field mapped to the model's text input
     * @return a new mutable map holding the three entries
     */
    private static HashMap<String, String> buildRecord(String id, String category, String chunkText) {
        HashMap<String, String> record = new HashMap<>();
        record.put("_id", id);
        record.put("category", category);
        record.put("chunk_text", chunkText);
        return record;
    }

    /**
     * Upserts four records and checks that:
     * <ul>
     *   <li>a text search for "Disease prevention" returns all four hits with {@code rec3} first,</li>
     *   <li>a search by the id of {@code rec1} returns exactly {@code rec1},</li>
     *   <li>a reranked text search still returns {@code rec3} first.</li>
     * </ul>
     * The index is deleted in a {@code finally} block so a failing assertion or API error
     * does not leave the index behind.
     *
     * @throws ApiException if a data plane call fails
     * @throws org.openapitools.db_control.client.ApiException if a control plane call fails
     * @throws InterruptedException if the thread is interrupted while sleeping
     */
    @Test
    public void upsertAndSearchRecordsTest() throws ApiException, org.openapitools.db_control.client.ApiException, InterruptedException {
        Pinecone pinecone = new Pinecone.Builder(System.getenv("PINECONE_API_KEY")).build();
        String indexName = RandomStringBuilder.build("inf", 8);
        HashMap<String, String> fieldMap = new HashMap<>();
        fieldMap.put("text", "chunk_text");
        CreateIndexForModelRequestEmbed embed = new CreateIndexForModelRequestEmbed()
                .model("multilingual-e5-large")
                .fieldMap(fieldMap);
        pinecone.createIndexForModel(indexName, CreateIndexForModelRequest.CloudEnum.AWS, "us-west-2", embed, DeletionProtection.DISABLED, new HashMap<>());

        // From here on the index exists, so everything runs under try/finally to guarantee cleanup.
        try {
            // Wait for index to be created
            Thread.sleep(INDEX_CREATION_WAIT_MS);

            Index index = pinecone.getIndexConnection(indexName);

            HashMap<String, String> record1 = buildRecord("rec1", "digestive system",
                    "Apples are a great source of dietary fiber, which supports digestion and helps maintain a healthy gut.");
            HashMap<String, String> record2 = buildRecord("rec2", "cultivation",
                    "Apples originated in Central Asia and have been cultivated for thousands of years, with over 7,500 varieties available today.");
            HashMap<String, String> record3 = buildRecord("rec3", "immune system",
                    "Rich in vitamin C and other antioxidants, apples contribute to immune health and may reduce the risk of chronic diseases.");
            HashMap<String, String> record4 = buildRecord("rec4", "endocrine system",
                    "The high fiber content in apples can also help regulate blood sugar levels, making them a favorable snack for people with diabetes.");

            ArrayList<Map<String, String>> upsertRecords = new ArrayList<>();
            upsertRecords.add(record1);
            upsertRecords.add(record2);
            upsertRecords.add(record3);
            upsertRecords.add(record4);

            index.upsertRecords(NAMESPACE, upsertRecords);

            HashMap<String, String> inputsMap = new HashMap<>();
            inputsMap.put("text", "Disease prevention");
            SearchRecordsRequestQuery query = new SearchRecordsRequestQuery()
                    .topK(4)
                    .inputs(inputsMap);

            List<String> fields = new ArrayList<>();
            fields.add("category");
            fields.add("chunk_text");

            // Wait for vectors to be upserted
            Thread.sleep(UPSERT_WAIT_MS);

            // Search by text inputs: every record is expected back, with rec3 ranked first.
            SearchRecordsResponse recordsResponse = index.searchRecords(NAMESPACE, query, fields, null);
            Assertions.assertEquals(upsertRecords.size(), recordsResponse.getResult().getHits().size());
            Assertions.assertEquals(record3.get("_id"), recordsResponse.getResult().getHits().get(0).getId());

            // Search by record id, asking for a single hit: the record itself is expected back.
            recordsResponse = index.searchRecordsById(record1.get("_id"), NAMESPACE, fields, 1, null, null);
            Assertions.assertEquals(1, recordsResponse.getResult().getHits().size());
            Assertions.assertEquals(record1.get("_id"), recordsResponse.getResult().getHits().get(0).getId());

            // Search by text with reranking on chunk_text: rec3 is still expected first.
            SearchRecordsRequestRerank rerank = new SearchRecordsRequestRerank()
                    .model("bge-reranker-v2-m3")
                    .topN(2)
                    .rankFields(Arrays.asList("chunk_text"));

            recordsResponse = index.searchRecordsByText("Disease prevention", NAMESPACE, fields, 4, null, rerank);
            Assertions.assertEquals(record3.get("_id"), recordsResponse.getResult().getHits().get(0).getId());
        } finally {
            // Always remove the index, even when an assertion or API call above fails.
            pinecone.deleteIndex(indexName);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public static void setUp() throws IOException, InterruptedException {
when(connectionMock.getBlockingStub()).thenReturn(stubMock);
when(connectionMock.getAsyncStub()).thenReturn(asyncStubMock);

index = new Index(connectionMock, "some-index-name");
index = new Index(config, connectionMock, "some-index-name");
asyncIndex = new AsyncIndex(config, connectionMock, "some-index-name");
}

Expand Down
5 changes: 3 additions & 2 deletions src/main/java/io/pinecone/clients/AsyncIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@

import static io.pinecone.clients.Pinecone.buildOkHttpClient;


/**
* A client for interacting with a Pinecone index via GRPC asynchronously. Allows for upserting, querying, fetching, updating, and deleting vectors.
* A client for interacting with a Pinecone index asynchronously. Allows for vector operations such as upserting,
* querying, fetching, updating, and deleting vectors along with records operations such as upsert and search records.
* This class provides a direct interface to interact with a specific index, encapsulating network communication and request validation.
* <p>
* Example:
Expand Down Expand Up @@ -70,6 +70,7 @@ public class AsyncIndex implements IndexInterface<ListenableFuture<UpsertRespons
* AsyncIndex asyncIndex = client.getAsyncIndexConnection("my-index");
* }</pre>
*
* @param config The {@link PineconeConfig} configuration of the index.
* @param connection The {@link PineconeConnection} configuration to be used for this index.
* @param indexName The name of the index to interact with. The index host will be automatically resolved.
* @throws PineconeValidationException if the connection object is null.
Expand Down
Loading
Loading