Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PgIdType based schema generation for PgVectorStore #2463

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,10 @@ public void afterPropertiesSet() {
// Enable the PGVector, JSONB and UUID support.
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");

if (this.idType == PgIdType.UUID) {
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
}

this.jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", this.getSchemaName()));

Expand All @@ -431,12 +434,12 @@ public void afterPropertiesSet() {

this.jdbcTemplate.execute(String.format("""
CREATE TABLE IF NOT EXISTS %s (
id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
id %s PRIMARY KEY,
content text,
metadata json,
embedding vector(%d)
)
""", this.getFullyQualifiedTableName(), this.embeddingDimensions()));
""", this.getFullyQualifiedTableName(), this.getColumnTypeName(), this.embeddingDimensions()));

if (this.createIndexMethod != PgIndexType.NONE) {
this.jdbcTemplate.execute(String.format("""
Expand Down Expand Up @@ -466,6 +469,16 @@ private String getVectorIndexName() {
return this.vectorIndexName;
}

/**
 * Resolves the SQL fragment that defines the type of the {@code id} column for the
 * id type reported by {@code getIdType()}. The fragment is inserted between
 * {@code id} and {@code PRIMARY KEY} in the {@code CREATE TABLE} statement. For
 * {@code UUID} it also carries a {@code uuid_generate_v4()} default.
 * @return the column type definition for the configured id type
 */
private String getColumnTypeName() {
	final String idColumnDefinition = switch (getIdType()) {
		case BIGSERIAL -> "bigserial";
		case SERIAL -> "serial";
		case INTEGER -> "integer";
		case TEXT -> "text";
		case UUID -> "uuid DEFAULT uuid_generate_v4()";
	};
	return idColumnDefinition;
}

int embeddingDimensions() {
// The manually set dimensions have precedence over the computed one.
if (this.dimensions > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,6 @@ public static String getText(String uri) {
}
}

/**
 * Manually prepares the database for a vector store whose {@code id} column is of
 * type {@code text}: creates the required extensions, the default schema and the
 * default table, all guarded by {@code IF NOT EXISTS}.
 * @param context application context providing the {@code PgVectorStore} and
 * {@code JdbcTemplate} beans
 */
private static void initSchema(ApplicationContext context) {
	PgVectorStore store = context.getBean(PgVectorStore.class);
	JdbcTemplate template = context.getBean(JdbcTemplate.class);

	// Enable the PGVector, JSONB and UUID support.
	String[] extensionStatements = { "CREATE EXTENSION IF NOT EXISTS vector",
			"CREATE EXTENSION IF NOT EXISTS hstore", "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"" };
	for (String extensionStatement : extensionStatements) {
		template.execute(extensionStatement);
	}

	String createSchemaSql = "CREATE SCHEMA IF NOT EXISTS %s".formatted(PgVectorStore.DEFAULT_SCHEMA_NAME);
	template.execute(createSchemaSql);

	// The embedding dimension is resolved only now, after the extensions and schema exist.
	String createTableSql = String.format("""
			CREATE TABLE IF NOT EXISTS %s.%s (
			id text PRIMARY KEY,
			content text,
			metadata json,
			embedding vector(%d)
			)
			""", PgVectorStore.DEFAULT_SCHEMA_NAME, PgVectorStore.DEFAULT_TABLE_NAME,
			store.embeddingDimensions());
	template.execute(createTableSql);
}

private static void dropTable(ApplicationContext context) {
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
Expand Down Expand Up @@ -218,21 +197,47 @@ public void testToPgTypeWithUuidIdType() {
}

@Test
public void testToPgTypeWithNonUuidIdType() {
public void testToPgTypeWithTextIdType() {
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
.run(context -> {

VectorStore vectorStore = context.getBean(VectorStore.class);
initSchema(context);

vectorStore.add(List.of(new Document("NOT_UUID", "TEXT", new HashMap<>())));

dropTable(context);
});
}

/**
 * Adds a single document with the numeric string id {@code "1"} to a store
 * configured with the {@code SERIAL} id type. The test creates no schema itself.
 * The table is dropped in a {@code finally} block so that a failing {@code add}
 * does not leave a {@code vector_store} table behind for later tests.
 */
@Test
public void testToPgTypeWithSerialIdType() {
	this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
		.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "SERIAL")
		.run(context -> {

			// Resolved outside the try block: if the bean cannot be obtained there is nothing to clean up.
			VectorStore vectorStore = context.getBean(VectorStore.class);

			try {
				vectorStore.add(List.of(new Document("1", "TEXT", new HashMap<>())));
			}
			finally {
				dropTable(context);
			}
		});
}

/**
 * Adds a single document with the numeric string id {@code "1"} to a store
 * configured with the {@code BIGSERIAL} id type. The test creates no schema itself.
 * The table is dropped in a {@code finally} block so that a failing {@code add}
 * does not leave a {@code vector_store} table behind for later tests.
 */
@Test
public void testToPgTypeWithBigSerialIdType() {
	this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
		.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "BIGSERIAL")
		.run(context -> {

			// Resolved outside the try block: if the bean cannot be obtained there is nothing to clean up.
			VectorStore vectorStore = context.getBean(VectorStore.class);

			try {
				vectorStore.add(List.of(new Document("1", "TEXT", new HashMap<>())));
			}
			finally {
				dropTable(context);
			}
		});
}

@Test
public void testBulkOperationWithUuidIdType() {
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
Expand All @@ -256,11 +261,9 @@ public void testBulkOperationWithUuidIdType() {
@Test
public void testBulkOperationWithNonUuidIdType() {
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
.run(context -> {
VectorStore vectorStore = context.getBean(VectorStore.class);
initSchema(context);

List<Document> documents = List.of(new Document("NON_UUID_1", "TEXT", new HashMap<>()),
new Document("NON_UUID_2", "TEXT", new HashMap<>()),
Expand Down