Skip to content

Commit c02430e

Browse files
CChuYongilayaperumalg
authored andcommitted
Add PgIdType based schema generation for PgVectorStore
Signed-off-by: CChuYong <[email protected]>
1 parent 49df625 commit c02430e

File tree

2 files changed

+45
-29
lines changed

2 files changed

+45
-29
lines changed

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java

+16-3
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,10 @@ public void afterPropertiesSet() {
420420
// Enable the PGVector, JSONB and UUID support.
421421
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
422422
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
423-
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
423+
424+
if (this.idType == PgIdType.UUID) {
425+
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
426+
}
424427

425428
this.jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", this.getSchemaName()));
426429

@@ -431,12 +434,12 @@ public void afterPropertiesSet() {
431434

432435
this.jdbcTemplate.execute(String.format("""
433436
CREATE TABLE IF NOT EXISTS %s (
434-
id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
437+
id %s PRIMARY KEY,
435438
content text,
436439
metadata json,
437440
embedding vector(%d)
438441
)
439-
""", this.getFullyQualifiedTableName(), this.embeddingDimensions()));
442+
""", this.getFullyQualifiedTableName(), this.getColumnTypeName(), this.embeddingDimensions()));
440443

441444
if (this.createIndexMethod != PgIndexType.NONE) {
442445
this.jdbcTemplate.execute(String.format("""
@@ -466,6 +469,16 @@ private String getVectorIndexName() {
466469
return this.vectorIndexName;
467470
}
468471

472+
private String getColumnTypeName() {
473+
return switch (getIdType()) {
474+
case UUID -> "uuid DEFAULT uuid_generate_v4()";
475+
case TEXT -> "text";
476+
case INTEGER -> "integer";
477+
case SERIAL -> "serial";
478+
case BIGSERIAL -> "bigserial";
479+
};
480+
}
481+
469482
int embeddingDimensions() {
470483
// The manually set dimensions have precedence over the computed one.
471484
if (this.dimensions > 0) {

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java

+29-26
Original file line numberDiff line numberDiff line change
@@ -112,27 +112,6 @@ public static String getText(String uri) {
112112
}
113113
}
114114

115-
private static void initSchema(ApplicationContext context) {
116-
PgVectorStore vectorStore = context.getBean(PgVectorStore.class);
117-
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
118-
// Enable the PGVector, JSONB and UUID support.
119-
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
120-
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
121-
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
122-
123-
jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", PgVectorStore.DEFAULT_SCHEMA_NAME));
124-
125-
jdbcTemplate.execute(String.format("""
126-
CREATE TABLE IF NOT EXISTS %s.%s (
127-
id text PRIMARY KEY,
128-
content text,
129-
metadata json,
130-
embedding vector(%d)
131-
)
132-
""", PgVectorStore.DEFAULT_SCHEMA_NAME, PgVectorStore.DEFAULT_TABLE_NAME,
133-
vectorStore.embeddingDimensions()));
134-
}
135-
136115
private static void dropTable(ApplicationContext context) {
137116
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
138117
jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
@@ -218,21 +197,47 @@ public void testToPgTypeWithUuidIdType() {
218197
}
219198

220199
@Test
221-
public void testToPgTypeWithNonUuidIdType() {
200+
public void testToPgTypeWithTextIdType() {
222201
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
223-
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
224202
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
225203
.run(context -> {
226204

227205
VectorStore vectorStore = context.getBean(VectorStore.class);
228-
initSchema(context);
229206

230207
vectorStore.add(List.of(new Document("NOT_UUID", "TEXT", new HashMap<>())));
231208

232209
dropTable(context);
233210
});
234211
}
235212

213+
@Test
214+
public void testToPgTypeWithSerialIdType() {
215+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
216+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "SERIAL")
217+
.run(context -> {
218+
219+
VectorStore vectorStore = context.getBean(VectorStore.class);
220+
221+
vectorStore.add(List.of(new Document("1", "TEXT", new HashMap<>())));
222+
223+
dropTable(context);
224+
});
225+
}
226+
227+
@Test
228+
public void testToPgTypeWithBigSerialIdType() {
229+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
230+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "BIGSERIAL")
231+
.run(context -> {
232+
233+
VectorStore vectorStore = context.getBean(VectorStore.class);
234+
235+
vectorStore.add(List.of(new Document("1", "TEXT", new HashMap<>())));
236+
237+
dropTable(context);
238+
});
239+
}
240+
236241
@Test
237242
public void testBulkOperationWithUuidIdType() {
238243
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
@@ -256,11 +261,9 @@ public void testBulkOperationWithUuidIdType() {
256261
@Test
257262
public void testBulkOperationWithNonUuidIdType() {
258263
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
259-
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
260264
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
261265
.run(context -> {
262266
VectorStore vectorStore = context.getBean(VectorStore.class);
263-
initSchema(context);
264267

265268
List<Document> documents = List.of(new Document("NON_UUID_1", "TEXT", new HashMap<>()),
266269
new Document("NON_UUID_2", "TEXT", new HashMap<>()),

0 commit comments

Comments
 (0)