diff --git a/backend/src/main/java/com/shyashyashya/refit/domain/qnaset/event/QuestionBatchEmbeddingEventHandler.java b/backend/src/main/java/com/shyashyashya/refit/domain/qnaset/event/QuestionBatchEmbeddingEventHandler.java index a66bf174..7b1e876e 100644 --- a/backend/src/main/java/com/shyashyashya/refit/domain/qnaset/event/QuestionBatchEmbeddingEventHandler.java +++ b/backend/src/main/java/com/shyashyashya/refit/domain/qnaset/event/QuestionBatchEmbeddingEventHandler.java @@ -80,7 +80,7 @@ private List> generateEmbeddings(List qnaSets) { for (int i = 0; i < qnaSets.size(); i++) { sendQnaSets.add(qnaSets.get(i)); if (sendQnaSets.size() == 100) { - var requests = new GeminiBatchEmbeddingRequest(qnaSets.stream() + var requests = new GeminiBatchEmbeddingRequest(sendQnaSets.stream() .map(qnaSet -> GeminiBatchEmbeddingRequest.GeminiEmbeddingRequest.of( qnaSet.getQuestionText(), GeminiEmbeddingRequest.TaskType.SEMANTIC_SIMILARITY, @@ -92,6 +92,15 @@ private List> generateEmbeddings(List qnaSets) { } } + var requests = new GeminiBatchEmbeddingRequest(sendQnaSets.stream() + .map(qnaSet -> GeminiBatchEmbeddingRequest.GeminiEmbeddingRequest.of( + qnaSet.getQuestionText(), + GeminiEmbeddingRequest.TaskType.SEMANTIC_SIMILARITY, + outputDimensionality)) + .toList()); + embeddings.addAll(geminiClient.sendAsyncBatchEmbeddingRequest(requests).join().embeddings().stream() + .toList()); + return embeddings.stream() .map(GeminiBatchEmbeddingResponse.Embedding::values) .toList(); diff --git a/backend/src/main/java/com/shyashyashya/refit/global/gemini/GeminiClient.java b/backend/src/main/java/com/shyashyashya/refit/global/gemini/GeminiClient.java index d8079f02..953191a2 100644 --- a/backend/src/main/java/com/shyashyashya/refit/global/gemini/GeminiClient.java +++ b/backend/src/main/java/com/shyashyashya/refit/global/gemini/GeminiClient.java @@ -69,6 +69,9 @@ public CompletableFuture sendAsyncEmbeddingRequest(Gemi public CompletableFuture sendAsyncBatchEmbeddingRequest( GeminiBatchEmbeddingRequest requestBody) { + log.info( + "[sendAsyncBatchEmbeddingRequest] send generate embedding in batch: request size {}", + requestBody.requests().size()); return webClient .post() .uri(EMBEDDING_BATCH_ENDPOINT)