diff --git a/src/main/java/com/techfork/domain/recommendation/config/RecommendationProperties.java b/src/main/java/com/techfork/domain/recommendation/config/RecommendationProperties.java index f85e398..2505bf8 100644 --- a/src/main/java/com/techfork/domain/recommendation/config/RecommendationProperties.java +++ b/src/main/java/com/techfork/domain/recommendation/config/RecommendationProperties.java @@ -15,9 +15,11 @@ @ConfigurationProperties(prefix = "recommendation") public class RecommendationProperties { - private Integer knnSearchSize = 100; + private Integer knnSearchSize = 80; - private Integer numCandidates = 200; + private Integer numCandidates = 180; + + private Integer mmrCandidateSize = 80; private Integer mmrFinalSize = 30; @@ -34,9 +36,9 @@ public class RecommendationProperties { @NoArgsConstructor @AllArgsConstructor public static class EmbeddingWeights { - private Float title = 0.5f; - private Float summary = 0.5f; - private Float content = 0.0f; + private Float title = 0.4f; + private Float summary = 0.4f; + private Float content = 0.2f; } @Getter diff --git a/src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java b/src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java index 6a0ed90..4161aed 100644 --- a/src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java +++ b/src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java @@ -1,7 +1,6 @@ package com.techfork.domain.recommendation.service; import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch._types.FieldValue; import co.elastic.clients.elasticsearch._types.KnnSearch; import co.elastic.clients.elasticsearch._types.query_dsl.Query; import co.elastic.clients.elasticsearch.core.SearchResponse; @@ -21,27 +20,27 @@ import com.techfork.domain.user.document.UserProfileDocument; import com.techfork.domain.user.entity.User; import 
com.techfork.domain.user.repository.UserProfileDocumentRepository; +import com.techfork.global.util.RrfScorer; import com.techfork.global.util.TimeDecayStrategy; import com.techfork.global.util.VectorUtil; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.context.annotation.Primary; import org.springframework.data.domain.PageRequest; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import java.io.IOException; import java.util.*; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; /** * MMR 알고리즘 기반 추천 전략 구현 - * - Elasticsearch k-NN 검색으로 초기 후보군 수집 - * - MMR 알고리즘으로 다양성 보장 - * - 읽은 글 제외 필터링 (Pre-filtering) - * - 시간 감쇠 가중치 적용 (최신 게시글 우선) */ @Slf4j @Service +@Primary @Transactional @RequiredArgsConstructor public class LlmRecommendationService implements RecommendationService { @@ -66,7 +65,6 @@ public class LlmRecommendationService implements RecommendationService { public int generateRecommendationsForUser(User user) { log.info("사용자 {} 추천 생성 시작", user.getId()); - // 1. 사용자 프로필 벡터 조회 Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) { log.warn("사용자 {}의 프로필 또는 벡터를 찾을 수 없음. 추천 생성 스킵.", user.getId()); @@ -85,14 +83,11 @@ public int generateRecommendationsForUser(User user) { return 0; } - log.info("사용자 {} 추천 후보 {} 개 발견", user.getId(), candidates.size()); - // 3. MMR 적용하여 최종 추천 선택 List mmrResults = mmrService.applyMmr(candidates); - // 4. 기존 추천을 이력으로 보관 (오늘 생성된 추천 포함) + // 4. 
기존 추천을 이력으로 보관 List oldRecommendations = recommendedPostRepository.findByUserOrderByRankAsc(user); - if (!oldRecommendations.isEmpty()) { List histories = oldRecommendations.stream() .map(RecommendationHistory::fromRecommendedPost) @@ -105,20 +100,12 @@ public int generateRecommendationsForUser(User user) { List recommendations = new ArrayList<>(); for (MmrResult result : mmrResults) { Post post = postRepository.getReferenceById(result.getPostId()); - RecommendedPost recommendedPost = RecommendedPost.create( - user, - post, - result.getSimilarityScore(), - result.getMmrScore(), - result.getRank() - ); - recommendations.add(recommendedPost); + recommendations.add(RecommendedPost.create( + user, post, result.getSimilarityScore(), result.getMmrScore(), result.getRank() + )); } recommendedPostRepository.saveAll(recommendations); - - log.info("사용자 {} 추천 생성 완료: {} 개", user.getId(), recommendations.size()); - return recommendations.size(); } catch (Exception e) { @@ -127,219 +114,129 @@ public int generateRecommendationsForUser(User user) { } } - /** - * 추천 생성 (평가 전용 - Train/Test Split 지원) - * 특정 읽은 글 목록(Train Set)만 제외하고 추천 생성 - * - * @param user 사용자 - * @param trainPostIds Train Set 게시글 ID 목록 (제외할 글) - * @return 추천된 게시글 ID 리스트 - */ - public List generateRecommendationsForEvaluation(User user, Set trainPostIds) { - // 1. 사용자 프로필 벡터 조회 - Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); - if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) { - log.warn("사용자 {}의 프로필 또는 벡터를 찾을 수 없음. 추천 생성 스킵.", user.getId()); - return Collections.emptyList(); - } - - float[] userProfileVector = profileOpt.get().getProfileVector(); - - try { - // 2. 
k-NN 검색으로 초기 후보군 가져오기 (Train Set만 제외) - List candidates = searchCandidatesWithCustomReadHistory(userProfileVector, user, trainPostIds); - - if (candidates.isEmpty()) { - log.debug("사용자 {}의 추천 후보군을 찾을 수 없음 (Train Set {} 개 제외)", user.getId(), trainPostIds.size()); - return Collections.emptyList(); - } - - // 3. MMR 적용하여 최종 추천 선택 - List mmrResults = mmrService.applyMmr(candidates); - - // 4. 추천된 게시글 ID 리스트 반환 - return mmrResults.stream() - .map(MmrResult::getPostId) - .toList(); - - } catch (Exception e) { - log.error("사용자 {} 추천 생성 실패 (Train/Test Split 평가용)", user.getId(), e); - return Collections.emptyList(); - } - } - - /** - * Elasticsearch 네이티브 k-NN 검색으로 초기 후보군 조회 (커스텀 읽은 글 목록) - * Train/Test Split 평가를 위해 Train Set만 제외 - */ - private List searchCandidatesWithCustomReadHistory( - float[] userProfileVector, - User user, - Set readPostIds) throws IOException { + private List searchCandidates(float[] userProfileVector, User user) throws IOException { + Set readPostIds = readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), PageRequest.of(0, 1000)) + .stream() + .map(readPost -> readPost.getPost().getId()) + .collect(Collectors.toSet()); - log.debug("사용자 {}의 읽은 게시글 {} 개 제외 (Train Set)", user.getId(), readPostIds.size()); + Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); + List keyKeywords = profileOpt.map(UserProfileDocument::getKeyKeywords).orElse(List.of()); - // 가중치 가져오기 RecommendationProperties.EmbeddingWeights weights = properties.getEmbeddingWeights(); + Query filterQuery = vectorQueryBuilder.createExcludeFilter(readPostIds); - // 랜덤 시드 생성 (현재 시간 기반) - long randomSeed = System.currentTimeMillis(); - double randomWeight = 0.0; - - // 1. 읽은 글 제외 필터 쿼리 생성 (Pre-filtering) - Query filterQuery = createExcludeFilter(readPostIds); - - // 2. 네이티브 k-NN 검색 객체 리스트 생성 (Title + Summary + Content) + // 1. 
kNN 검색 쿼리 준비 List knnSearches = vectorQueryBuilder.createKnnSearches( - TITLE_EMBEDDING_FIELD, - SUMMARY_EMBEDDING_FIELD, - CONTENT_CHUNKS_EMBEDDING_FIELD, - userProfileVector, - weights.getTitle(), - weights.getSummary(), - weights.getContent(), - properties.getKnnSearchSize(), - properties.getNumCandidates(), - filterQuery + TITLE_EMBEDDING_FIELD, SUMMARY_EMBEDDING_FIELD, CONTENT_CHUNKS_EMBEDDING_FIELD, + userProfileVector, weights.getTitle(), weights.getSummary(), weights.getContent(), + properties.getKnnSearchSize(), properties.getNumCandidates(), filterQuery ); - // 3. 랜덤 요소 추가 (function_score) - Query randomQuery = vectorQueryBuilder.createRandomScoreQuery(randomSeed, randomWeight); - - log.debug("ES k-NN 검색 실행 (Train/Test Split) - 가중치 [title:{}, summary:{}], 랜덤가중치: {}", - weights.getTitle(), weights.getSummary(), randomWeight); - - long startTime = System.currentTimeMillis(); - SearchResponse response = elasticsearchClient.search(s -> s - .index(POSTS_INDEX) - .knn(knnSearches) // k-NN 검색 (관련성 + 필터링) - .query(randomQuery) // 랜덤 점수 추가 - .size(properties.getKnnSearchSize()) - , - PostDocument.class + // 2. 
BM25 검색 쿼리 준비 + Query bm25Query = vectorQueryBuilder.createBm25Query( + keyKeywords, weights.getTitle(), weights.getSummary(), weights.getContent() ); - long duration = System.currentTimeMillis() - startTime; - log.info("추천 후보군 검색 완료 (Evaluation): {} 개, 소요 시간: {}ms", response.hits().hits().size(), duration); - // 결과를 MmrCandidate로 변환 - return response.hits().hits().stream() - .filter(hit -> hit.source() != null) - .map(this::mapToMmrCandidate) - .filter(candidate -> candidate.getSummaryVector() != null) - .toList(); - } + long startTime = System.currentTimeMillis(); - /** - * Elasticsearch 네이티브 k-NN 검색으로 초기 후보군 조회 - * - 이미 읽은 글 제외 - * - 랜덤 시드를 사용하여 매번 다른 후보군 생성 - */ - private List searchCandidates(float[] userProfileVector, User user) throws IOException { - // 이미 읽은 글 ID 목록 - Set readPostIds = readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), PageRequest.of(0, 1000)) - .stream() - .map(readPost -> readPost.getPost().getId()) - .collect(Collectors.toSet()); + // 3. 
kNN과 BM25 검색 병렬 실행 + CompletableFuture>> vectorSearchFuture = CompletableFuture.supplyAsync(() -> { + try { + SearchResponse response = elasticsearchClient.search(s -> s + .index(POSTS_INDEX) + .knn(knnSearches) + .size(properties.getKnnSearchSize()), + PostDocument.class + ); + return response.hits().hits(); + } catch (IOException e) { + log.error("kNN 검색 실패", e); + return Collections.emptyList(); + } + }); - log.debug("사용자 {}의 읽은 게시글 {} 개 제외", user.getId(), readPostIds.size()); + CompletableFuture>> keywordSearchFuture = CompletableFuture.supplyAsync(() -> { + // 키워드가 없으면 BM25 검색 생략 + if (bm25Query == null) { + log.debug("키워드가 없어 BM25 검색 생략"); + return Collections.emptyList(); + } + try { + SearchResponse response = elasticsearchClient.search(s -> s + .index(POSTS_INDEX) + .query(q -> q.bool(b -> { + b.must(bm25Query); + if (filterQuery != null) b.filter(filterQuery); + return b; + })) + .size(properties.getKnnSearchSize()), + PostDocument.class + ); + return response.hits().hits(); + } catch (IOException e) { + log.error("BM25 검색 실패", e); + return Collections.emptyList(); + } + }); - // 가중치 가져오기 - RecommendationProperties.EmbeddingWeights weights = properties.getEmbeddingWeights(); + // 4. 두 검색 완료 대기 + CompletableFuture allSearches = CompletableFuture.allOf(vectorSearchFuture, keywordSearchFuture); + allSearches.join(); - // 랜덤 시드 생성 (현재 시간 기반) - long randomSeed = System.currentTimeMillis(); - double randomWeight = 0.0; // 랜덤 가중치 20% + List> vectorHits = vectorSearchFuture.join(); + List> keywordHits = keywordSearchFuture.join(); - // 1. 읽은 글 제외 필터 쿼리 생성 (Pre-filtering) - Query filterQuery = createExcludeFilter(readPostIds); + log.info("후보군 검색 완료: kNN {} 개, BM25 {} 개, 소요 시간: {}ms", + vectorHits.size(), keywordHits.size(), System.currentTimeMillis() - startTime); - // 2. 
네이티브 k-NN 검색 객체 리스트 생성 (Title + Summary + Content) - List knnSearches = vectorQueryBuilder.createKnnSearches( - TITLE_EMBEDDING_FIELD, - SUMMARY_EMBEDDING_FIELD, - CONTENT_CHUNKS_EMBEDDING_FIELD, - userProfileVector, - weights.getTitle(), - weights.getSummary(), - weights.getContent(), - properties.getKnnSearchSize(), - properties.getNumCandidates(), - filterQuery - ); - - // 3. 랜덤 요소 추가 (function_score) - Query randomQuery = vectorQueryBuilder.createRandomScoreQuery(randomSeed, randomWeight); - - log.debug("ES k-NN 검색 실행 - 가중치 [title:{}, summary:{}], 랜덤시드: {}, 랜덤가중치: {}", - weights.getTitle(), weights.getSummary(), randomSeed, randomWeight); + // 5. RRF로 결합 + return applyRrf(vectorHits, keywordHits); + } - long startTime = System.currentTimeMillis(); - SearchResponse response = elasticsearchClient.search(s -> s - .index(POSTS_INDEX) - .knn(knnSearches) // k-NN 검색 (관련성 + 필터링) - .query(randomQuery) // 랜덤 점수 추가 - .size(properties.getKnnSearchSize()) - , - PostDocument.class - ); - long duration = System.currentTimeMillis() - startTime; - log.info("추천 후보군 검색 완료: {} 개, 소요 시간: {}ms", response.hits().hits().size(), duration); + protected List applyRrf(List> vectorHits, List> keywordHits) { + // Post ID 리스트 추출 (null 체크) + List vectorPostIds = vectorHits.stream() + .filter(hit -> hit.source() != null) + .map(hit -> hit.source().getPostId()) + .toList(); - // 결과를 MmrCandidate로 변환 - return response.hits().hits().stream() + List keywordPostIds = keywordHits.stream() .filter(hit -> hit.source() != null) - .map(this::mapToMmrCandidate) - .filter(candidate -> candidate.getSummaryVector() != null) + .map(hit -> hit.source().getPostId()) .toList(); - } - /** - * 읽은 글 제외를 위한 필터 쿼리 생성 - */ - private Query createExcludeFilter(Set readPostIds) { - if (readPostIds == null || readPostIds.isEmpty()) { - return null; - } + // RRF 스코어 계산 + Map rrfScores = RrfScorer.calculateRrfScores(vectorPostIds, keywordPostIds); - List excludeValues = readPostIds.stream() - .map(FieldValue::of) + // 
Hit을 postId 기준으로 맵핑 + Map> hitMap = new HashMap<>(); + vectorHits.stream() + .filter(hit -> hit.source() != null) + .forEach(hit -> hitMap.putIfAbsent(hit.source().getPostId(), hit)); + keywordHits.stream() + .filter(hit -> hit.source() != null) + .forEach(hit -> hitMap.putIfAbsent(hit.source().getPostId(), hit)); + + // RRF 스코어 순으로 정렬하여 MMR Candidate 생성 + // MMR 성능을 위해 상위 N개만 선택 (MMR은 O(n²)이므로 후보 수 제한 필요) + return rrfScores.entrySet().stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .limit(properties.getMmrCandidateSize()) + .map(entry -> mapToMmrCandidate(hitMap.get(entry.getKey()), entry.getValue())) + .filter(candidate -> candidate.getSummaryVector() != null) .toList(); - - return Query.of(q -> q - .bool(b -> b - .mustNot(mn -> mn - .terms(t -> t - .field("postId") - .terms(v -> v.value(excludeValues)) - ) - ) - ) - ); } - /** - * PostDocument를 MmrCandidate로 변환 - * 시간 감쇠 가중치를 유사도 점수에 적용 - */ - private MmrCandidate mapToMmrCandidate(Hit hit) { + protected MmrCandidate mapToMmrCandidate(Hit hit, double rrfScore) { PostDocument doc = hit.source(); - double score = Objects.requireNonNullElse(hit.score(), 0.0); - - // 시간 감쇠 가중치 적용 double timeDecayWeight = timeDecayStrategy.calculateWeight(Objects.requireNonNull(doc).getPublishedAt()); - double adjustedScore = score * timeDecayWeight; - - log.trace("게시글 {} 점수 조정: 원본={}, 시간가중치={}, 최종={}", - doc.getPostId(), score, timeDecayWeight, adjustedScore); - - float[] titleVector = VectorUtil.convertToFloatArray(doc.getTitleEmbedding()); - float[] summaryVector = VectorUtil.convertToFloatArray(doc.getSummaryEmbedding()); return MmrCandidate.builder() .postId(doc.getPostId()) - .titleVector(titleVector) - .summaryVector(summaryVector) - .similarityScore(adjustedScore) + .titleVector(VectorUtil.convertToFloatArray(doc.getTitleEmbedding())) + .summaryVector(VectorUtil.convertToFloatArray(doc.getSummaryEmbedding())) + .similarityScore(rrfScore * timeDecayWeight) .build(); } } diff --git 
a/src/main/java/com/techfork/domain/recommendation/service/MmrService.java b/src/main/java/com/techfork/domain/recommendation/service/MmrService.java index af72fa6..90565d7 100644 --- a/src/main/java/com/techfork/domain/recommendation/service/MmrService.java +++ b/src/main/java/com/techfork/domain/recommendation/service/MmrService.java @@ -10,6 +10,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Random; /** * MMR (Maximal Marginal Relevance) 알고리즘 구현 @@ -21,6 +22,7 @@ public class MmrService { private final RecommendationProperties properties; + private final Random random = new Random(); @Getter @Builder @@ -40,6 +42,18 @@ public static class MmrResult { private int rank; } + private static class ScoredCandidate { + MmrCandidate candidate; + double mmrScore; + int originalIndex; + + ScoredCandidate(MmrCandidate candidate, double mmrScore, int originalIndex) { + this.candidate = candidate; + this.mmrScore = mmrScore; + this.originalIndex = originalIndex; + } + } + /** * MMR 알고리즘을 적용하여 다양성을 보장하는 추천 결과 생성 * @@ -61,8 +75,10 @@ public List applyMmr(List candidates) { log.debug("MMR 선택 시작: candidates={}, finalSize={}, lambda={}", candidates.size(), finalSize, lambda); - // 첫 번째는 가장 유사도가 높은 문서 선택 - MmrCandidate first = remainingCandidates.remove(0); + // 첫 번째는 상위 K개 중에서 랜덤하게 선택 (다양성 증가) + int topK = Math.min(5, remainingCandidates.size()); + int randomIndex = random.nextInt(topK); + MmrCandidate first = remainingCandidates.remove(randomIndex); selectedResults.add(MmrResult.builder() .postId(first.getPostId()) .similarityScore(first.getSimilarityScore()) @@ -70,32 +86,31 @@ public List applyMmr(List candidates) { .rank(1) .build()); - // 나머지 문서들을 MMR 점수 기반으로 선택 + // 나머지 문서들을 MMR 점수 기반으로 선택 (Top-K 샘플링으로 랜덤성 추가) while (selectedResults.size() < finalSize && !remainingCandidates.isEmpty()) { - MmrCandidate bestCandidate = null; - double bestMmrScore = Double.NEGATIVE_INFINITY; - int bestIndex = -1; - + // 모든 후보의 MMR 점수 계산 + List scoredCandidates = new 
ArrayList<>(); for (int i = 0; i < remainingCandidates.size(); i++) { MmrCandidate candidate = remainingCandidates.get(i); double mmrScore = calculateMmrScore(candidate, selectedResults, lambda, candidates); - - if (mmrScore > bestMmrScore) { - bestMmrScore = mmrScore; - bestCandidate = candidate; - bestIndex = i; - } + scoredCandidates.add(new ScoredCandidate(candidate, mmrScore, i)); } - if (bestCandidate != null) { - remainingCandidates.remove(bestIndex); - selectedResults.add(MmrResult.builder() - .postId(bestCandidate.getPostId()) - .similarityScore(bestCandidate.getSimilarityScore()) - .mmrScore(bestMmrScore) - .rank(selectedResults.size() + 1) - .build()); - } + // MMR 점수 내림차순 정렬 + scoredCandidates.sort((a, b) -> Double.compare(b.mmrScore, a.mmrScore)); + + // 상위 K개 중에서 랜덤 선택 + int topKForSelection = Math.min(3, scoredCandidates.size()); + int randomIdx = random.nextInt(topKForSelection); + ScoredCandidate selected = scoredCandidates.get(randomIdx); + + remainingCandidates.remove(selected.originalIndex); + selectedResults.add(MmrResult.builder() + .postId(selected.candidate.getPostId()) + .similarityScore(selected.candidate.getSimilarityScore()) + .mmrScore(selected.mmrScore) + .rank(selectedResults.size() + 1) + .build()); } log.info("MMR 선택 완료: 전체 {} 후보 중 {} 개 선택", diff --git a/src/main/java/com/techfork/domain/search/service/SearchServiceImpl.java b/src/main/java/com/techfork/domain/search/service/SearchServiceImpl.java index ab7ca46..7a465ca 100644 --- a/src/main/java/com/techfork/domain/search/service/SearchServiceImpl.java +++ b/src/main/java/com/techfork/domain/search/service/SearchServiceImpl.java @@ -14,6 +14,7 @@ import com.techfork.domain.user.document.UserProfileDocument; import com.techfork.domain.user.repository.UserProfileDocumentRepository; import com.techfork.global.llm.EmbeddingClient; +import com.techfork.global.util.RrfScorer; import com.techfork.global.util.VectorUtil; import java.io.IOException; import java.util.ArrayList; @@ -245,21 
+246,29 @@ private KnnSearch createKnnSearch(String field, List vector, int k, int n } private List calculateRRF(List> lexicalHits, List> semanticHits) { - Map lexicalRankMap = new HashMap<>(); - AtomicInteger rank = new AtomicInteger(1); - lexicalHits.forEach(hit -> lexicalRankMap.put(hit.id(), rank.getAndIncrement())); - - Map semanticRankMap = new HashMap<>(); - rank.set(1); - semanticHits.forEach(hit -> semanticRankMap.put(hit.id(), rank.getAndIncrement())); - - Map combinedResults = new HashMap<>(); - Map rrfScores = new HashMap<>(); - - processHitsForRRF(lexicalHits, lexicalRankMap, rrfScores, combinedResults); - processHitsForRRF(semanticHits, semanticRankMap, rrfScores, combinedResults); + // Hit ID 리스트 추출 + List lexicalIds = lexicalHits.stream().map(Hit::id).toList(); + List semanticIds = semanticHits.stream().map(Hit::id).toList(); + + // RRF 스코어 계산 + Map rrfScores = RrfScorer.calculateRrfScores(lexicalIds, semanticIds); + + // Hit을 docId 기준으로 맵핑 (semantic 우선 - 벡터 포함 보장) + Map> hitMap = new HashMap<>(); + lexicalHits.forEach(hit -> hitMap.put(hit.id(), hit)); + semanticHits.forEach(hit -> hitMap.put(hit.id(), hit)); // semantic 결과로 덮어쓰기 (벡터 포함) + + // SearchResult로 변환 + Map resultMap = new HashMap<>(); + for (Map.Entry> entry : hitMap.entrySet()) { + String docId = entry.getKey(); + Hit hit = entry.getValue(); + SearchResult result = mapToSearchResult(hit); + resultMap.put(docId, result); + } - return combinedResults.values().stream() + // 최종 스코어 적용 및 정렬 + return resultMap.values().stream() .map(searchResult -> { double finalScore = rrfScores.get(searchResult.getPostId().toString()); return searchResult.toBuilder() @@ -272,41 +281,6 @@ private List calculateRRF(List> lexicalHits, Lis .collect(Collectors.toList()); } - private void processHitsForRRF(List> hits, - Map rankMap, - Map rrfScores, - Map combinedResults) { - hits.forEach(hit -> { - String docId = hit.id(); - double score = 1.0 / (generalSearchProperties.getRRF_K() + rankMap.get(docId)); - 
rrfScores.merge(docId, score, Double::sum); - - SearchResult newResult = mapToSearchResult(hit); - - if (!combinedResults.containsKey(docId)) { - combinedResults.put(docId, newResult); - } else { - SearchResult existing = combinedResults.get(docId); - boolean needUpdate = false; - SearchResult.SearchResultBuilder builder = existing.toBuilder(); - - if (existing.getTitleVector() == null && newResult.getTitleVector() != null) { - builder.titleVector(newResult.getTitleVector()); - needUpdate = true; - } - - if (existing.getSummaryVector() == null && newResult.getSummaryVector() != null) { - builder.summaryVector(newResult.getSummaryVector()); - needUpdate = true; - } - - if (needUpdate) { - combinedResults.put(docId, builder.build()); - } - } - }); - } - private SearchResult mapToSearchResult(Hit hit) { PostDocument doc = hit.source(); double score = Objects.requireNonNullElse(hit.score(), 0.0); diff --git a/src/main/java/com/techfork/domain/user/document/UserProfileDocument.java b/src/main/java/com/techfork/domain/user/document/UserProfileDocument.java index 716c96c..b23c774 100644 --- a/src/main/java/com/techfork/domain/user/document/UserProfileDocument.java +++ b/src/main/java/com/techfork/domain/user/document/UserProfileDocument.java @@ -36,28 +36,33 @@ public class UserProfileDocument { @Field(type = FieldType.Keyword) private List interests; + @Field(type = FieldType.Keyword) + private List keyKeywords; + @Field(type = FieldType.Date) @Transient private LocalDateTime generatedAt; @Builder private UserProfileDocument(Long userId, String profileText, float[] profileVector, - List interests, LocalDateTime generatedAt) { + List interests, List keyKeywords, LocalDateTime generatedAt) { this.id = String.valueOf(userId); this.userId = userId; this.profileText = profileText; this.profileVector = profileVector; this.interests = interests; + this.keyKeywords = keyKeywords; this.generatedAt = generatedAt; } public static UserProfileDocument create(Long userId, String 
profileText, float[] profileVector, - List interests) { + List interests, List keyKeywords) { return UserProfileDocument.builder() .userId(userId) .profileText(profileText) .profileVector(profileVector) .interests(interests) + .keyKeywords(keyKeywords) .generatedAt(LocalDateTime.now()) .build(); } diff --git a/src/main/java/com/techfork/domain/user/service/UserProfileService.java b/src/main/java/com/techfork/domain/user/service/UserProfileService.java index 7a93d49..a45ad9c 100644 --- a/src/main/java/com/techfork/domain/user/service/UserProfileService.java +++ b/src/main/java/com/techfork/domain/user/service/UserProfileService.java @@ -25,6 +25,7 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -57,14 +58,17 @@ public void generateUserProfile(Long userId) { public void generateUserProfileSync(Long userId) { try { UserActivityData activityData = collectUserActivityData(userId); - String profileText = generateProfileTextWithLLM(activityData); - float[] profileVector = generateEmbeddingVector(profileText); + String llmResponse = generateProfileTextWithLLM(activityData); + + ProfileAndKeywords parsed = parseProfileAndKeywords(llmResponse); + float[] profileVector = generateEmbeddingVector(parsed.profileText); UserProfileDocument profileDocument = UserProfileDocument.create( userId, - profileText, + parsed.profileText, profileVector, - activityData.interests + activityData.interests, + parsed.keyKeywords ); userProfileDocumentRepository.save(profileDocument); @@ -143,7 +147,7 @@ private String generateProfileTextWithLLM(UserActivityData data) { private String buildProfileGenerationPrompt(UserActivityData data) { return String.format(""" - 아래 사용자의 활동 데이터를 분석하여 검색 고도화와 포스트 추천에 최적화된 프로필을 생성해주세요. + 아래 사용자의 활동 데이터를 분석하여 검색 리랭킹과 포스트 추천에 최적화된 프로필을 생성해주세요. 
## 사용자 데이터 @@ -161,28 +165,29 @@ private String buildProfileGenerationPrompt(UserActivityData data) { ## 요구사항 - 다음 형식으로 구조화된 프로필을 생성해주세요: - - 1. **기술적 관심사 요약** (2-3문장) - - 사용자가 주로 관심을 갖는 기술 스택, 프레임워크, 도구 - - 선호하는 개발 분야 (백엔드, 프론트엔드, AI, 인프라 등) + 반드시 아래 형식으로 응답해주세요: - 2. **콘텐츠 선호 패턴** (2-3문장) - - 읽은 포스트와 스크랩한 포스트를 분석하여 선호하는 주제와 기술 파악 - - 선호하는 회사/팀이나 콘텐츠 유형 (튜토리얼, 아키텍처, 트러블슈팅 등) + ### PROFILE + 사용자의 기술적 관심사, 학습 패턴, 선호도를 의미 밀도 높고 풍부하게 표현한 텍스트를 작성하세요 (200-300자 정도). - 3. **검색 의도 분석** (2-3문장) - - 검색 기록에서 드러나는 학습 목적이나 해결하려는 문제 - - 반복되는 검색 주제나 패턴 + 다음 내용을 모두 포함하되 자연스러운 문장으로 작성: + 1. 주요 관심 기술 스택과 개발 분야 (백엔드/프론트엔드/인프라/AI 등) + 2. 선호하는 주제와 학습 방향 (아키텍처 설계, 성능 최적화, 트러블슈팅, 신기술 탐구 등) + 3. 읽은 포스트와 검색 기록에서 드러나는 구체적인 관심사 + 4. 현재 해결하려는 문제나 학습 중인 영역 + 5. 콘텐츠 선호 패턴 (심화 기술, 실전 경험, 튜토리얼 등) - 4. **추천 키워드** (쉼표로 구분된 15-20개의 키워드) - - 검색 쿼리 확장에 사용할 관련 기술 용어 - - 유사한 관심사를 가진 사용자가 찾을 만한 키워드 - - 영문과 한글 키워드 모두 포함 + 주의사항: + - 마크다운 없이 순수 텍스트로만 작성 (볼드, 이탤릭, 리스트, 번호 금지) + - 구체적인 기술 용어를 많이 사용하여 임베딩 품질 향상 + - "관심이 있습니다", "선호합니다" 같은 메타 표현 대신 직접적인 기술 용어 나열 - 5. **프로필 요약** (1-2문장, 벡터 임베딩 최적화용) - - 사용자의 기술적 페르소나를 한 줄로 압축 - - 추천 시스템이 유사 사용자를 찾는데 활용할 핵심 설명 + ### KEYWORDS + 사용자의 현재 관심사를 가장 잘 대표하는 핵심 키워드 3-5개를 쉼표로 구분하여 나열하세요. + - 구체적이고 검색 의도가 명확한 키워드만 선택 + - BM25 검색에 사용되므로 검색어로 자주 쓰일 만한 용어 선택 + - 예: Kubernetes, React hooks, 분산 트랜잭션, 성능 최적화, MSA + - 영문과 한글 혼용 가능 데이터가 부족한 경우 관심 기술 스택을 기반으로 일반적인 프로필을 생성해주세요. 
""", @@ -247,6 +252,43 @@ private String convertReadingDurationToNaturalLanguage(Integer durationSeconds) } } + private ProfileAndKeywords parseProfileAndKeywords(String llmResponse) { + String profileText = ""; + List keyKeywords = List.of(); + + try { + // PROFILE 섹션 추출 + int profileStart = llmResponse.indexOf("### PROFILE"); + int keywordsStart = llmResponse.indexOf("### KEYWORDS"); + + if (profileStart != -1 && keywordsStart != -1) { + profileText = llmResponse.substring(profileStart + "### PROFILE".length(), keywordsStart) + .trim(); + + String keywordsSection = llmResponse.substring(keywordsStart + "### KEYWORDS".length()) + .trim(); + + // 쉼표로 구분된 키워드 파싱 + keyKeywords = Arrays.stream(keywordsSection.split(",")) + .map(String::trim) + .filter(s -> !s.isEmpty()) + .limit(5) // 최대 5개 + .toList(); + } else { + // 파싱 실패 시 전체 텍스트를 프로필로 사용 + log.warn("Failed to parse LLM response sections, using full text as profile"); + profileText = llmResponse; + } + } catch (Exception e) { + log.error("Error parsing LLM response", e); + profileText = llmResponse; + } + + return new ProfileAndKeywords(profileText, keyKeywords); + } + + private record ProfileAndKeywords(String profileText, List keyKeywords) {} + private record UserActivityData( List interests, List readPostData, diff --git a/src/main/java/com/techfork/global/elasticsearch/query/VectorQueryBuilder.java b/src/main/java/com/techfork/global/elasticsearch/query/VectorQueryBuilder.java index c435993..16a2c98 100644 --- a/src/main/java/com/techfork/global/elasticsearch/query/VectorQueryBuilder.java +++ b/src/main/java/com/techfork/global/elasticsearch/query/VectorQueryBuilder.java @@ -3,6 +3,7 @@ import co.elastic.clients.elasticsearch._types.KnnSearch; import co.elastic.clients.elasticsearch._types.query_dsl.Query; import java.util.List; +import java.util.Set; /** * Elasticsearch 벡터 검색 쿼리 빌더 인터페이스 @@ -10,6 +11,14 @@ */ public interface VectorQueryBuilder { + /** + * 읽은 글 제외를 위한 필터 쿼리 생성 (Pre-filtering용) + * + * @param 
readPostIds 제외할 게시글 ID 목록 + * @return Elasticsearch Query 객체 + */ + Query createExcludeFilter(Set readPostIds); + /** * 네이티브 k-NN 검색 객체 리스트 생성 * (title, summary, content 필드에 대한 k-NN 검색) @@ -40,11 +49,13 @@ List createKnnSearches( ); /** - * 랜덤 점수를 위한 function_score 쿼리 생성 + * BM25 키워드 검색 쿼리 생성 * - * @param randomSeed 랜덤 시드 - * @param randomWeight 랜덤 가중치 - * @return function_score 쿼리 + * @param keywords 검색할 키워드 리스트 + * @param titleBoost 제목 필드 가중치 + * @param summaryBoost 요약 필드 가중치 + * @param contentBoost 본문 필드 가중치 + * @return BM25 검색 Query 객체 (키워드가 없으면 null) */ - Query createRandomScoreQuery(long randomSeed, double randomWeight); + Query createBm25Query(List keywords, float titleBoost, float summaryBoost, float contentBoost); } \ No newline at end of file diff --git a/src/main/java/com/techfork/global/elasticsearch/query/VectorSearchQueryBuilder.java b/src/main/java/com/techfork/global/elasticsearch/query/VectorSearchQueryBuilder.java index 3a978bf..7b49748 100644 --- a/src/main/java/com/techfork/global/elasticsearch/query/VectorSearchQueryBuilder.java +++ b/src/main/java/com/techfork/global/elasticsearch/query/VectorSearchQueryBuilder.java @@ -1,7 +1,8 @@ package com.techfork.global.elasticsearch.query; +import co.elastic.clients.elasticsearch._types.FieldValue; import co.elastic.clients.elasticsearch._types.KnnSearch; -import co.elastic.clients.elasticsearch._types.query_dsl.FunctionBoostMode; +import co.elastic.clients.elasticsearch._types.query_dsl.ChildScoreMode; import co.elastic.clients.elasticsearch._types.query_dsl.Query; import lombok.AccessLevel; import lombok.NoArgsConstructor; @@ -9,6 +10,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Set; /** * Elasticsearch 벡터 검색 쿼리 빌더 구현체 @@ -18,6 +20,28 @@ @NoArgsConstructor(access = AccessLevel.PRIVATE) public class VectorSearchQueryBuilder implements VectorQueryBuilder { + @Override + public Query createExcludeFilter(Set readPostIds) { + if (readPostIds == null || readPostIds.isEmpty()) { 
+ return null; + } + + List excludeValues = readPostIds.stream() + .map(FieldValue::of) + .toList(); + + return Query.of(q -> q + .bool(b -> b + .mustNot(mn -> mn + .terms(t -> t + .field("postId") + .terms(v -> v.value(excludeValues)) + ) + ) + ) + ); + } + @Override public List createKnnSearches( String titleField, @@ -40,10 +64,10 @@ public List createKnnSearches( if (titleWeight > 0) { knnSearches.add(KnnSearch.of(ks -> { ks.field(titleField) - .queryVector(vectorList) - .k(k) - .numCandidates(numCandidates) - .boost(titleWeight); + .queryVector(vectorList) + .k(k) + .numCandidates(numCandidates) + .boost(titleWeight); if (filter != null) { ks.filter(filter); } @@ -54,10 +78,10 @@ public List createKnnSearches( if (summaryWeight > 0) { knnSearches.add(KnnSearch.of(ks -> { ks.field(summaryField) - .queryVector(vectorList) - .k(k) - .numCandidates(numCandidates) - .boost(summaryWeight); + .queryVector(vectorList) + .k(k) + .numCandidates(numCandidates) + .boost(summaryWeight); if (filter != null) { ks.filter(filter); } @@ -68,10 +92,10 @@ public List createKnnSearches( if (contentWeight > 0 && contentField != null) { knnSearches.add(KnnSearch.of(ks -> { ks.field(contentField) - .queryVector(vectorList) - .k(k) - .numCandidates(numCandidates) - .boost(contentWeight); + .queryVector(vectorList) + .k(k) + .numCandidates(numCandidates) + .boost(contentWeight); if (filter != null) { ks.filter(filter); } @@ -83,18 +107,43 @@ public List createKnnSearches( } @Override - public Query createRandomScoreQuery(long randomSeed, double randomWeight) { + public Query createBm25Query(List keywords, float titleBoost, float summaryBoost, float contentBoost) { + if (keywords == null || keywords.isEmpty()) { + return null; + } + + String combinedKeywords = String.join(" ", keywords); + return Query.of(q -> q - .functionScore(fs -> fs - .query(mq -> mq.matchAll(m -> m)) - .functions(fn -> fn - .randomScore(rs -> rs - .seed(String.valueOf(randomSeed)) - .field("_seq_no") + .bool(b -> 
b + .should(s -> s + .match(m -> m + .field("title") + .query(combinedKeywords) + .boost(titleBoost) + ) + ) + .should(s -> s + .match(m -> m + .field("summary") + .query(combinedKeywords) + .boost(summaryBoost) + ) + ) + .should(s -> s + .nested(n -> n + .path("contentChunks") + .scoreMode(ChildScoreMode.Max) + .query(nq -> nq + .match(m -> m + .field("contentChunks.text") + .query(combinedKeywords) + ) + ) + .boost(contentBoost) ) - .weight(randomWeight) ) - .boostMode(FunctionBoostMode.Sum) + .minimumShouldMatch("1") ) ); } diff --git a/src/main/java/com/techfork/global/util/RrfScorer.java b/src/main/java/com/techfork/global/util/RrfScorer.java new file mode 100644 index 0000000..2360f43 --- /dev/null +++ b/src/main/java/com/techfork/global/util/RrfScorer.java @@ -0,0 +1,48 @@ +package com.techfork.global.util; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Reciprocal Rank Fusion (RRF) 스코어 계산 유틸리티 + * 여러 검색 결과를 결합하여 하나의 통합 점수를 생성 + * k=60 고정 사용 + */ +public class RrfScorer { + + private static final int K = 60; + + /** + * RRF 스코어 계산 (k=60 고정) + * + * @param resultsLists 여러 검색 결과 리스트 (각 리스트는 순위대로 정렬되어 있어야 함) + * @param 결과 항목의 타입 + * @return 각 항목의 ID와 RRF 스코어 맵 + */ + public static Map calculateRrfScores(List> resultsLists) { + Map rrfScores = new HashMap<>(); + + for (List results : resultsLists) { + for (int rank = 0; rank < results.size(); rank++) { + T item = results.get(rank); + double score = 1.0 / (K + rank + 1); + rrfScores.merge(item, score, Double::sum); + } + } + + return rrfScores; + } + + /** + * 두 개의 검색 결과를 RRF로 결합 (k=60 고정) + * + * @param firstResults 첫 번째 검색 결과 + * @param secondResults 두 번째 검색 결과 + * @param 결과 항목의 타입 + * @return 각 항목의 ID와 RRF 스코어 맵 + */ + public static Map calculateRrfScores(List firstResults, List secondResults) { + return calculateRrfScores(List.of(firstResults, secondResults)); + } +} \ No newline at end of file diff --git a/src/main/resources/application.yml 
b/src/main/resources/application.yml index cd30af7..41491fe 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -94,16 +94,16 @@ webhook: url: ${DISCORD_WEBHOOK_URL} recommendation: - knn-search-size: 100 - num-candidates: 200 + knn-search-size: 80 + num-candidates: 180 mmr-final-size: 30 lambda: 0.95 active-user-hours: 24 # 임베딩 가중치 설정 (합계 1.0) embedding-weights: - title: 0.5 # 제목 중요도 50% - summary: 0.5 # 요약 중요도 50% - content: 0.0 # 콘텐츠 청크 중요도 0% (제외) + title: 0.4 # 제목 중요도 40% + summary: 0.4 # 요약 중요도 40% + content: 0.2 # 콘텐츠 청크 중요도 20% # 시간 감쇠 가중치 설정 time-decay: days-7: 1.3 # 최근 7일: +30% diff --git a/src/test/java/com/techfork/domain/recommendation/evaluation/KValueComparisonTest.java b/src/test/java/com/techfork/domain/recommendation/evaluation/KValueComparisonTest.java new file mode 100644 index 0000000..278f60a --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/evaluation/KValueComparisonTest.java @@ -0,0 +1,250 @@ +package com.techfork.domain.recommendation.evaluation; + +import com.techfork.domain.recommendation.config.RecommendationProperties; +import com.techfork.domain.user.entity.User; +import lombok.Builder; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import java.util.*; + +/** + * kNN 검색 크기(k) 값에 따른 성능 및 품질 비교 테스트 + */ +@Tag("evaluation") +@Slf4j +public class KValueComparisonTest extends RecommendationTestBase { + + @Test + @DisplayName("knnSearchSize와 numCandidates 값 비교 평가") + void compareKValues() { + log.info("===== K 값에 따른 성능 및 품질 비교 ====="); + log.info("Ground-Truth: {} 명 사용자", cachedGroundTruth.size()); + + List kConfigs = createKConfigs(); + List testUsers = getTestUsers(); + log.info("테스트 사용자: {} 명", testUsers.size()); + + printKComparisonHeader(); + List results = evaluateAllKConfigs(kConfigs, testUsers); + printBestKResult(results); + } + + /** + * 테스트할 
K 값 설정 생성 + */ + private List createKConfigs() { + return Arrays.asList( + // 현재 기본값 + KConfig.builder().name("현재 (50/100)") + .knnSearchSize(50).numCandidates(100).build(), + + KConfig.builder().name("중간-하 (60/120)") + .knnSearchSize(60).numCandidates(120).build(), + + // 중간 값 + KConfig.builder().name("중간 (70/150)") + .knnSearchSize(70).numCandidates(150).build(), + + KConfig.builder().name("중간-상 (80/180)") + .knnSearchSize(80).numCandidates(180).build(), + + // 이전 값 + KConfig.builder().name("이전 (100/200)") + .knnSearchSize(100).numCandidates(200).build() + ); + } + + /** + * 모든 K 설정에 대해 평가 수행 + */ + private List evaluateAllKConfigs(List kConfigs, List testUsers) { + List results = new ArrayList<>(); + + for (KConfig kConfig : kConfigs) { + long startTime = System.currentTimeMillis(); + + // Properties 생성 + RecommendationProperties properties = new RecommendationProperties(); + properties.setKnnSearchSize(kConfig.knnSearchSize); + properties.setNumCandidates(kConfig.numCandidates); + properties.setMmrCandidateSize(80); + properties.setMmrFinalSize(30); + properties.setLambda(1.0); // 다양성 제외, 관련성만 + + // 가중치는 최적값으로 고정 (제목+요약 중심) + RecommendationProperties.EmbeddingWeights weights = new RecommendationProperties.EmbeddingWeights(); + weights.setTitle(0.4f); + weights.setSummary(0.4f); + weights.setContent(0.2f); + properties.setEmbeddingWeights(weights); + + // 평가 수행 - UserMetrics 수집 + List userMetrics = testUsers.stream() + .map(user -> evaluateUserWithGroundTruth(user, properties)) + .filter(Optional::isPresent) + .map(Optional::get) + .toList(); + + // 평균 메트릭 계산 + KMetrics avgMetrics = calculateAverageKMetrics(userMetrics); + + long elapsedTime = System.currentTimeMillis() - startTime; + + KResult result = KResult.builder() + .name(kConfig.name) + .knnSearchSize(kConfig.knnSearchSize) + .numCandidates(kConfig.numCandidates) + .metrics(avgMetrics) + .executionTimeMs(elapsedTime) + .build(); + + results.add(result); + printKResult(result); + } + + return results; + 
} + + private KMetrics calculateAverageKMetrics(List userMetrics) { + double r4 = userMetrics.stream().mapToDouble(UserMetrics::getRecall4).average().orElse(0.0); + double n4 = userMetrics.stream().mapToDouble(UserMetrics::getNdcg4).average().orElse(0.0); + double r8 = userMetrics.stream().mapToDouble(UserMetrics::getRecall8).average().orElse(0.0); + double n8 = userMetrics.stream().mapToDouble(UserMetrics::getNdcg8).average().orElse(0.0); + double r30 = userMetrics.stream().mapToDouble(UserMetrics::getRecall30).average().orElse(0.0); + double n30 = userMetrics.stream().mapToDouble(UserMetrics::getNdcg30).average().orElse(0.0); + + return KMetrics.builder() + .recallAt4(r4) + .ndcgAt4(n4) + .recallAt8(r8) + .ndcgAt8(n8) + .recallAt30(r30) + .ndcgAt30(n30) + .build(); + } + + private void printKComparisonHeader() { + log.info(""); + log.info("설정 | K값 | Candidates | R@4 | R@8 | R@30 | nDCG@4 | nDCG@8 | nDCG@30 | 실행시간"); + log.info("----------------------------------------------------------------------------------------------"); + } + + private void printKResult(KResult result) { + log.info(String.format("%-30s | %-9s | %-10s | %.4f | %.4f | %.4f | %.4f | %.4f | %.4f | %dms", + result.name, + result.knnSearchSize, + result.numCandidates, + result.metrics.recallAt4, + result.metrics.recallAt8, + result.metrics.recallAt30, + result.metrics.ndcgAt4, + result.metrics.ndcgAt8, + result.metrics.ndcgAt30, + result.executionTimeMs + )); + } + + private void printBestKResult(List results) { + log.info(""); + log.info("===== 최고 성능 K 값 조합 ====="); + + // Recall@8 최고 + KResult bestRecall = results.stream() + .max(Comparator.comparing(r -> r.metrics.recallAt8)) + .orElse(null); + + // nDCG@8 최고 + KResult bestNdcg = results.stream() + .max(Comparator.comparing(r -> r.metrics.ndcgAt8)) + .orElse(null); + + // 균형 점수 최고 (Recall@8 + nDCG@8 평균) + KResult bestBalance = results.stream() + .max(Comparator.comparing(r -> (r.metrics.recallAt8 + r.metrics.ndcgAt8) / 2.0)) + .orElse(null); + + 
// 실행 시간 최단 + KResult fastest = results.stream() + .min(Comparator.comparing(r -> r.executionTimeMs)) + .orElse(null); + + log.info(""); + log.info("[Recall@8 최고]"); + if (bestRecall != null) { + log.info(String.format("%s (K=%d, C=%d) | R@8: %.4f | nDCG@8: %.4f | 시간: %dms", + bestRecall.name, bestRecall.knnSearchSize, bestRecall.numCandidates, + bestRecall.metrics.recallAt8, bestRecall.metrics.ndcgAt8, bestRecall.executionTimeMs)); + } + + log.info(""); + log.info("[nDCG@8 최고]"); + if (bestNdcg != null) { + log.info(String.format("%s (K=%d, C=%d) | R@8: %.4f | nDCG@8: %.4f | 시간: %dms", + bestNdcg.name, bestNdcg.knnSearchSize, bestNdcg.numCandidates, + bestNdcg.metrics.recallAt8, bestNdcg.metrics.ndcgAt8, bestNdcg.executionTimeMs)); + } + + log.info(""); + log.info("[균형 점수 최고 (R@8 + nDCG@8 평균)]"); + if (bestBalance != null) { + double balanceScore = (bestBalance.metrics.recallAt8 + bestBalance.metrics.ndcgAt8) / 2.0; + log.info(String.format("%s (K=%d, C=%d) | R@8: %.4f | nDCG@8: %.4f | 균형: %.4f | 시간: %dms", + bestBalance.name, bestBalance.knnSearchSize, bestBalance.numCandidates, + bestBalance.metrics.recallAt8, bestBalance.metrics.ndcgAt8, balanceScore, + bestBalance.executionTimeMs)); + } + + log.info(""); + log.info("[실행 시간 최단]"); + if (fastest != null) { + log.info(String.format("%s (K=%d, C=%d) | R@8: %.4f | nDCG@8: %.4f | 시간: %dms", + fastest.name, fastest.knnSearchSize, fastest.numCandidates, + fastest.metrics.recallAt8, fastest.metrics.ndcgAt8, fastest.executionTimeMs)); + } + + log.info(""); + log.info("===== 성능/품질 트레이드오프 분석 ====="); + results.forEach(r -> { + double efficiency = (r.metrics.recallAt8 + r.metrics.ndcgAt8) / 2.0 / (r.executionTimeMs / 1000.0); + log.info(String.format("%s: 효율성 지수 = %.4f (품질: %.4f, 시간: %.1fs)", + r.name, + efficiency, + (r.metrics.recallAt8 + r.metrics.ndcgAt8) / 2.0, + r.executionTimeMs / 1000.0 + )); + }); + } + + @Getter + @Builder + private static class KConfig { + private String name; + private int knnSearchSize; + 
private int numCandidates; + } + + @Getter + @Builder + private static class KMetrics { + private double recallAt4; + private double ndcgAt4; + private double recallAt8; + private double ndcgAt8; + private double recallAt30; + private double ndcgAt30; + } + + @Getter + @Builder + private static class KResult { + private String name; + private int knnSearchSize; + private int numCandidates; + private KMetrics metrics; + private long executionTimeMs; + } +} \ No newline at end of file diff --git a/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationEvaluationService.java b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationEvaluationService.java new file mode 100644 index 0000000..cf61246 --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationEvaluationService.java @@ -0,0 +1,200 @@ +package com.techfork.domain.recommendation.evaluation; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.search.Hit; +import co.elastic.clients.elasticsearch._types.query_dsl.Query; +import co.elastic.clients.elasticsearch._types.KnnSearch; +import com.techfork.domain.activity.repository.ReadPostRepository; +import com.techfork.domain.post.document.PostDocument; +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.post.repository.PostRepository; +import com.techfork.domain.recommendation.config.RecommendationProperties; +import com.techfork.domain.recommendation.entity.RecommendedPost; +import com.techfork.domain.recommendation.repository.RecommendationHistoryRepository; +import com.techfork.domain.recommendation.repository.RecommendedPostRepository; +import com.techfork.domain.recommendation.service.LlmRecommendationService; +import com.techfork.domain.recommendation.service.MmrService; +import com.techfork.domain.recommendation.service.MmrService.MmrCandidate; 
+import com.techfork.domain.recommendation.service.MmrService.MmrResult; +import com.techfork.domain.user.document.UserProfileDocument; +import com.techfork.domain.user.entity.User; +import com.techfork.domain.user.repository.UserProfileDocumentRepository; +import com.techfork.global.elasticsearch.query.VectorQueryBuilder; +import com.techfork.global.util.TimeDecayStrategy; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.io.IOException; +import java.util.*; +import java.util.concurrent.CompletableFuture; + +/** + * 추천 시스템 성능 평가를 위한 전용 서비스 + * LlmRecommendationService를 상속하여 RRF, MMR 로직 재사용 + */ +@Slf4j +@Service +public class RecommendationEvaluationService extends LlmRecommendationService { + + private final UserProfileDocumentRepository userProfileDocumentRepository; + private final VectorQueryBuilder vectorQueryBuilder; + private final ElasticsearchClient elasticsearchClient; + + private static final String POSTS_INDEX = "posts"; + private static final String TITLE_EMBEDDING_FIELD = "titleEmbedding"; + private static final String SUMMARY_EMBEDDING_FIELD = "summaryEmbedding"; + private static final String CONTENT_CHUNKS_EMBEDDING_FIELD = "contentChunks.embedding"; + + public RecommendationEvaluationService( + ElasticsearchClient elasticsearchClient, + UserProfileDocumentRepository userProfileDocumentRepository, + RecommendedPostRepository recommendedPostRepository, + RecommendationHistoryRepository recommendationHistoryRepository, + ReadPostRepository readPostRepository, + PostRepository postRepository, + MmrService mmrService, + TimeDecayStrategy timeDecayStrategy, + RecommendationProperties properties, + VectorQueryBuilder vectorQueryBuilder + ) { + super(elasticsearchClient, userProfileDocumentRepository, recommendedPostRepository, + recommendationHistoryRepository, readPostRepository, postRepository, + mmrService, timeDecayStrategy, properties, vectorQueryBuilder); + this.elasticsearchClient = 
elasticsearchClient; + this.userProfileDocumentRepository = userProfileDocumentRepository; + this.vectorQueryBuilder = vectorQueryBuilder; + } + + /** + * 추천 생성 (평가 전용 - Train/Test Split 지원) + */ + public List generateRecommendationsForEvaluation(User user, Set trainPostIds, RecommendationProperties properties) { + long totalStartTime = System.currentTimeMillis(); + + Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); + if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) { + return Collections.emptyList(); + } + + UserProfileDocument profile = profileOpt.get(); + float[] userProfileVector = profile.getProfileVector(); + List keyKeywords = profile.getKeyKeywords(); + + try { + List candidates = searchCandidatesWithCustomReadHistory(userProfileVector, keyKeywords, trainPostIds, properties); + + if (candidates.isEmpty()) { + return Collections.emptyList(); + } + + // MMR 적용 (테스트용 properties 사용) + long mmrStartTime = System.currentTimeMillis(); + MmrService mmrService = new MmrService(properties); + List mmrResults = mmrService.applyMmr(candidates); + long mmrElapsedTime = System.currentTimeMillis() - mmrStartTime; + log.info("[EVAL] MMR 실행 시간: {}ms (후보 {}개 → 결과 {}개)", mmrElapsedTime, candidates.size(), mmrResults.size()); + + List result = mmrResults.stream() + .map(MmrResult::getPostId) + .toList(); + + long totalElapsedTime = System.currentTimeMillis() - totalStartTime; + log.info("[EVAL] 전체 추천 로직 실행 시간: {}ms (사용자 ID: {})", totalElapsedTime, user.getId()); + + return result; + + } catch (Exception e) { + log.error("사용자 {} 평가용 추천 생성 실패", user.getId(), e); + return Collections.emptyList(); + } + } + + private List searchCandidatesWithCustomReadHistory( + float[] userProfileVector, + List keyKeywords, + Set readPostIds, + RecommendationProperties properties) throws IOException { + + RecommendationProperties.EmbeddingWeights weights = properties.getEmbeddingWeights(); + Query filterQuery = 
vectorQueryBuilder.createExcludeFilter(readPostIds); + + // 1. kNN 검색 쿼리 준비 + List knnSearches = vectorQueryBuilder.createKnnSearches( + TITLE_EMBEDDING_FIELD, SUMMARY_EMBEDDING_FIELD, CONTENT_CHUNKS_EMBEDDING_FIELD, + userProfileVector, weights.getTitle(), weights.getSummary(), weights.getContent(), + properties.getKnnSearchSize(), properties.getNumCandidates(), filterQuery + ); + + // 2. BM25 검색 쿼리 준비 + Query bm25Query = vectorQueryBuilder.createBm25Query( + keyKeywords, weights.getTitle(), weights.getSummary(), weights.getContent() + ); + + // 3. kNN과 BM25 검색 병렬 실행 + long searchStartTime = System.currentTimeMillis(); + + CompletableFuture>> vectorSearchFuture = CompletableFuture.supplyAsync(() -> { + try { + long knnStartTime = System.currentTimeMillis(); + SearchResponse response = elasticsearchClient.search(s -> s + .index(POSTS_INDEX) + .knn(knnSearches) + .size(properties.getKnnSearchSize()), + PostDocument.class + ); + long knnElapsedTime = System.currentTimeMillis() - knnStartTime; + log.info("[EVAL] kNN 검색 실행 시간: {}ms", knnElapsedTime); + return response.hits().hits(); + } catch (IOException e) { + log.error("kNN 검색 실패", e); + return Collections.emptyList(); + } + }); + + CompletableFuture>> keywordSearchFuture = CompletableFuture.supplyAsync(() -> { + // 키워드가 없으면 BM25 검색 생략 + if (bm25Query == null) { + log.debug("[EVAL] 키워드가 없어 BM25 검색 생략"); + return Collections.emptyList(); + } + try { + long bm25StartTime = System.currentTimeMillis(); + SearchResponse response = elasticsearchClient.search(s -> s + .index(POSTS_INDEX) + .query(q -> q.bool(b -> { + b.must(bm25Query); + if (filterQuery != null) b.filter(filterQuery); + return b; + })) + .size(properties.getKnnSearchSize()), + PostDocument.class + ); + long bm25ElapsedTime = System.currentTimeMillis() - bm25StartTime; + log.info("[EVAL] BM25 검색 실행 시간: {}ms", bm25ElapsedTime); + return response.hits().hits(); + } catch (IOException e) { + log.error("BM25 검색 실패", e); + return Collections.emptyList(); + } + 
}); + + // 4. 두 검색 완료 대기 + CompletableFuture allSearches = CompletableFuture.allOf(vectorSearchFuture, keywordSearchFuture); + allSearches.join(); + + List> vectorHits = vectorSearchFuture.join(); + List> keywordHits = keywordSearchFuture.join(); + + long searchTotalTime = System.currentTimeMillis() - searchStartTime; + log.info("[EVAL] 검색 총 소요 시간: {}ms (kNN: {}개, BM25: {}개)", searchTotalTime, vectorHits.size(), keywordHits.size()); + + // 5. RRF로 결합 (부모 클래스의 protected 메서드 사용) + long rrfStartTime = System.currentTimeMillis(); + List candidates = applyRrf(vectorHits, keywordHits); + long rrfElapsedTime = System.currentTimeMillis() - rrfStartTime; + log.info("[EVAL] RRF 결합 실행 시간: {}ms (결과: {}개)", rrfElapsedTime, candidates.size()); + + return candidates; + } +} \ No newline at end of file diff --git a/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationTestBase.java b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationTestBase.java index 2d9b0ed..44f5b50 100644 --- a/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationTestBase.java +++ b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationTestBase.java @@ -1,21 +1,11 @@ package com.techfork.domain.recommendation.evaluation; -import co.elastic.clients.elasticsearch.ElasticsearchClient; import com.techfork.domain.activity.repository.ReadPostRepository; import com.techfork.domain.post.repository.PostDocumentRepository; -import com.techfork.domain.post.repository.PostRepository; import com.techfork.domain.recommendation.config.RecommendationProperties; -import com.techfork.domain.recommendation.repository.RecommendationHistoryRepository; -import com.techfork.domain.recommendation.repository.RecommendedPostRepository; -import com.techfork.domain.recommendation.service.LlmRecommendationService; -import com.techfork.domain.recommendation.service.MmrService; import com.techfork.domain.recommendation.util.EvaluationFixtureLoader; import 
com.techfork.domain.user.entity.User; -import com.techfork.domain.user.enums.EInterestCategory; -import com.techfork.domain.user.repository.UserProfileDocumentRepository; import com.techfork.global.common.IntegrationTestBase; -import com.techfork.global.elasticsearch.query.VectorQueryBuilder; -import com.techfork.global.util.TimeDecayStrategy; import com.techfork.global.util.VectorUtil; import lombok.AllArgsConstructor; import lombok.Builder; @@ -25,7 +15,6 @@ import org.junit.jupiter.api.TestInstance; import org.springframework.beans.factory.annotation.Autowired; -import java.io.IOException; import java.util.*; /** @@ -36,9 +25,9 @@ public abstract class RecommendationTestBase extends IntegrationTestBase { // 테스트 상수 - protected static final int K_FIRST_ROW = 4; // 첫 줄 - protected static final int K_FIRST_SCREEN = 8; // 첫 화면 - protected static final int K_DEEP_EXPLORE = 30; // 깊은 탐색 + protected static final int K_FIRST_ROW = 4; + protected static final int K_FIRST_SCREEN = 8; + protected static final int K_DEEP_EXPLORE = 30; protected static final float DEFAULT_TITLE_WEIGHT = 0.4f; protected static final float DEFAULT_SUMMARY_WEIGHT = 0.4f; @@ -50,15 +39,9 @@ public abstract class RecommendationTestBase extends IntegrationTestBase { @Autowired protected EvaluationFixtureLoader fixtureLoader; @Autowired protected RecommendationQualityService qualityService; + @Autowired protected RecommendationEvaluationService evaluationService; // 새로운 서비스 @Autowired protected PostDocumentRepository postDocumentRepository; - @Autowired protected ElasticsearchClient elasticsearchClient; - @Autowired protected UserProfileDocumentRepository userProfileDocumentRepository; - @Autowired protected RecommendedPostRepository recommendedPostRepository; - @Autowired protected RecommendationHistoryRepository recommendationHistoryRepository; @Autowired protected ReadPostRepository readPostRepository; - @Autowired protected PostRepository postRepository; - @Autowired protected TimeDecayStrategy 
timeDecayStrategy; - @Autowired protected VectorQueryBuilder vectorQueryBuilder; @Autowired protected com.techfork.domain.user.repository.UserRepository userRepository; protected static List cachedTestUsers; @@ -115,10 +98,6 @@ protected List getTestUsers() { return cachedTestUsers; } - protected double calculateCompositeScore(double recall, double ndcg, double ild) { - return recall * RECALL_WEIGHT + ndcg * NDCG_WEIGHT + ild * ILD_WEIGHT; - } - protected RecommendationProperties createProperties(float tw, float sw, float cw, double lambda) { RecommendationProperties props = new RecommendationProperties(); props.setKnnSearchSize(100); @@ -132,15 +111,6 @@ protected RecommendationProperties createProperties(float tw, float sw, float cw return props; } - protected LlmRecommendationService createRecommendationService(RecommendationProperties props) { - MmrService mmrService = new MmrService(props); - return new LlmRecommendationService( - elasticsearchClient, userProfileDocumentRepository, recommendedPostRepository, - recommendationHistoryRepository, readPostRepository, postRepository, - mmrService, timeDecayStrategy, props, vectorQueryBuilder - ); - } - protected EvaluationResult calculateAverageMetrics(String configName, List metrics) { double r4 = metrics.stream().mapToDouble(UserMetrics::getRecall4).average().orElse(0.0); double n4 = metrics.stream().mapToDouble(UserMetrics::getNdcg4).average().orElse(0.0); @@ -149,7 +119,8 @@ protected EvaluationResult calculateAverageMetrics(String configName, List evaluateUserWithGroundTruth(User user, LlmRecommendationService service) { + protected Optional evaluateUserWithGroundTruth(User user, RecommendationProperties props) { try { Map groundTruth = cachedGroundTruth.get(user.getId()); if (groundTruth == null || groundTruth.isEmpty()) return Optional.empty(); @@ -171,7 +139,8 @@ protected Optional evaluateUserWithGroundTruth(User user, LlmRecomm Set readIds = 
readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), org.springframework.data.domain.PageRequest.of(0, 10000)) .stream().map(rp -> rp.getPost().getId()).collect(java.util.stream.Collectors.toSet()); - List recIds = service.generateRecommendationsForEvaluation(user, readIds); + // 새로운 서비스 사용 + List recIds = evaluationService.generateRecommendationsForEvaluation(user, readIds, props); if (recIds.isEmpty()) return Optional.empty(); double r4 = qualityService.calculateRecall(recIds, groundTruth.keySet(), K_FIRST_ROW); @@ -187,10 +156,7 @@ protected Optional evaluateUserWithGroundTruth(User user, LlmRecomm } } - /** - * ILD 포함 평가 (Lambda 최적화용) - */ - protected Optional evaluateUserWithGroundTruthAndILD(User user, LlmRecommendationService service) { + protected Optional evaluateUserWithGroundTruthAndILD(User user, RecommendationProperties props) { try { Map groundTruth = cachedGroundTruth.get(user.getId()); if (groundTruth == null || groundTruth.isEmpty()) return Optional.empty(); @@ -198,7 +164,8 @@ protected Optional evaluateUserWithGroundTruthAndILD(User user, Llm Set readIds = readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), org.springframework.data.domain.PageRequest.of(0, 10000)) .stream().map(rp -> rp.getPost().getId()).collect(java.util.stream.Collectors.toSet()); - List recIds = service.generateRecommendationsForEvaluation(user, readIds); + // 새로운 서비스 사용 + List recIds = evaluationService.generateRecommendationsForEvaluation(user, readIds, props); if (recIds.isEmpty()) return Optional.empty(); double r4 = qualityService.calculateRecall(recIds, groundTruth.keySet(), K_FIRST_ROW); @@ -221,11 +188,10 @@ protected Optional evaluateUserWithGroundTruthAndILD(User user, Llm } protected EvaluationResult evaluateConfigWithGroundTruth(ConfigCombo config, List testUsers) { - LlmRecommendationService service = createRecommendationService( - createProperties(config.getTitleWeight(), config.getSummaryWeight(), 
config.getContentWeight(), config.getMmrLambda())); + RecommendationProperties props = createProperties(config.getTitleWeight(), config.getSummaryWeight(), config.getContentWeight(), config.getMmrLambda()); List metrics = testUsers.stream() - .map(user -> evaluateUserWithGroundTruth(user, service)) + .map(user -> evaluateUserWithGroundTruth(user, props)) .filter(Optional::isPresent) .map(Optional::get) .toList(); @@ -234,11 +200,10 @@ protected EvaluationResult evaluateConfigWithGroundTruth(ConfigCombo config, Lis } protected EvaluationResult evaluateConfigWithGroundTruthAndILD(ConfigCombo config, List testUsers) { - LlmRecommendationService service = createRecommendationService( - createProperties(config.getTitleWeight(), config.getSummaryWeight(), config.getContentWeight(), config.getMmrLambda())); + RecommendationProperties props = createProperties(config.getTitleWeight(), config.getSummaryWeight(), config.getContentWeight(), config.getMmrLambda()); List metrics = testUsers.stream() - .map(user -> evaluateUserWithGroundTruthAndILD(user, service)) + .map(user -> evaluateUserWithGroundTruthAndILD(user, props)) .filter(Optional::isPresent) .map(Optional::get) .toList(); @@ -268,4 +233,4 @@ protected void printLambdaOptimizationResult(EvaluationResult result) { log.info(String.format("%-25s | %.4f | %.4f | %.4f | %.4f", result.getConfigName(), result.getAvgRecall8(), result.getAvgNdcg8(), result.getAvgIld(), result.getCompositeScore())); } -} +} \ No newline at end of file diff --git a/src/test/java/com/techfork/domain/recommendation/setup/UserDataSetupAndExporter.java b/src/test/java/com/techfork/domain/recommendation/setup/UserDataSetupAndExporter.java index 6d5c301..e777701 100644 --- a/src/test/java/com/techfork/domain/recommendation/setup/UserDataSetupAndExporter.java +++ b/src/test/java/com/techfork/domain/recommendation/setup/UserDataSetupAndExporter.java @@ -55,7 +55,7 @@ public class UserDataSetupAndExporter extends IntegrationTestBase { @Autowired private 
FileExporter fileExporter; - private static final int USER_COUNT = 5; + private static final int USER_COUNT = 15; private static final int READ_POST_COUNT = 80; // 프로필 구성용 (읽은 글) - 1100개 데이터셋 기준 (약 7%) private static final int HOLDOUT_COUNT = 30; // Ground Truth (평가용, 숨김) - 평가 샘플 (약 2.7%) @@ -86,18 +86,37 @@ void step1_LoadPostFixtures() { @Test @Order(2) - @DisplayName("STEP 2: 테스트 사용자 5명 생성 (임베딩 포함)") + @DisplayName("STEP 2: 테스트 사용자 15명 생성 (임베딩 포함)") @Transactional @Commit void step2_CreateTestUsers() throws IOException { log.info("===== STEP 2: 테스트 사용자 생성 ====="); List> interestCombos = Arrays.asList( - Arrays.asList(EInterestCategory.BACKEND), - Arrays.asList(EInterestCategory.FRONTEND), - Arrays.asList(EInterestCategory.AI_ML), + // Backend 중심 (4명) Arrays.asList(EInterestCategory.BACKEND, EInterestCategory.DATABASE), - Arrays.asList(EInterestCategory.AI_ML, EInterestCategory.DATA_SCIENCE) + Arrays.asList(EInterestCategory.BACKEND, EInterestCategory.ARCHITECTURE), + Arrays.asList(EInterestCategory.BACKEND, EInterestCategory.SECURITY), + Arrays.asList(EInterestCategory.BACKEND, EInterestCategory.FRONTEND), + + // Frontend 중심 (3명) + Arrays.asList(EInterestCategory.FRONTEND, EInterestCategory.PRODUCT_UX), + Arrays.asList(EInterestCategory.FRONTEND, EInterestCategory.ARCHITECTURE), + Arrays.asList(EInterestCategory.FRONTEND), + + // Data & AI (3명) + Arrays.asList(EInterestCategory.AI_ML, EInterestCategory.DATA_SCIENCE), + Arrays.asList(EInterestCategory.DATA_ENGINEERING, EInterestCategory.DATABASE), + Arrays.asList(EInterestCategory.AI_ML, EInterestCategory.CLOUD), + + // DevOps & Infrastructure (3명) + Arrays.asList(EInterestCategory.DEVOPS, EInterestCategory.CLOUD), + Arrays.asList(EInterestCategory.CLOUD, EInterestCategory.ARCHITECTURE), + Arrays.asList(EInterestCategory.SYSTEMS_OS, EInterestCategory.NETWORKING), + + // Mobile (2명) + Arrays.asList(EInterestCategory.IOS, EInterestCategory.ANDROID), + Arrays.asList(EInterestCategory.IOS, 
EInterestCategory.PRODUCT_UX) ); Map> userGroundTruthMap = new HashMap<>(); @@ -308,6 +327,7 @@ private Map convertUserProfileToDto(UserProfileDocument profile) } dto.put("interests", profile.getInterests()); + dto.put("keyKeywords", profile.getKeyKeywords()); return dto; } diff --git a/src/test/java/com/techfork/domain/recommendation/setup/components/UserTestDataBuilder.java b/src/test/java/com/techfork/domain/recommendation/setup/components/UserTestDataBuilder.java index e014771..aa91a50 100644 --- a/src/test/java/com/techfork/domain/recommendation/setup/components/UserTestDataBuilder.java +++ b/src/test/java/com/techfork/domain/recommendation/setup/components/UserTestDataBuilder.java @@ -20,10 +20,8 @@ import org.springframework.stereotype.Component; import java.time.LocalDateTime; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.UUID; +import java.util.*; +import java.util.stream.Collectors; @Slf4j @Component @@ -97,13 +95,15 @@ public void createReadPosts(User user, List posts) { public void createScrapPosts(User user, List readPosts, int scrapCount) { LocalDateTime now = LocalDateTime.now(); - List scrabPosts = new ArrayList<>(); - List postsToScrap = new ArrayList<>(readPosts); + List postsToScrap = readPosts.stream() + .distinct() + .collect(Collectors.toList()); Collections.shuffle(postsToScrap); int actualScrapCount = Math.min(scrapCount, postsToScrap.size()); + List scrabPosts = new ArrayList<>(); for (int i = 0; i < actualScrapCount; i++) { Post post = postsToScrap.get(i); ScrabPost scrabPost = ScrabPost.create( diff --git a/src/test/java/com/techfork/domain/recommendation/util/EvaluationFixtureLoader.java b/src/test/java/com/techfork/domain/recommendation/util/EvaluationFixtureLoader.java index 9fc365c..6a3fb85 100644 --- a/src/test/java/com/techfork/domain/recommendation/util/EvaluationFixtureLoader.java +++ b/src/test/java/com/techfork/domain/recommendation/util/EvaluationFixtureLoader.java @@ 
-335,6 +335,7 @@ private int loadUserProfiles(Map userMap) throws IOException { Long originalUserId = ((Number) dto.get("userId")).longValue(); String profileText = (String) dto.get("profileText"); List interests = (List) dto.get("interests"); + List keyKeywords = (List) dto.get("keyKeywords"); // JSON의 원래 User ID를 실제 DB User ID로 매핑 User user = userMap.get(originalUserId); @@ -359,6 +360,7 @@ private int loadUserProfiles(Map userMap) throws IOException { .profileText(profileText) .profileVector(profileVector) .interests(interests) + .keyKeywords(keyKeywords) .build(); userProfileDocumentRepository.save(profile); diff --git a/src/test/java/com/techfork/global/configuration/IntegrationTestConfig.java b/src/test/java/com/techfork/global/configuration/IntegrationTestConfig.java index 0002fed..b0862de 100644 --- a/src/test/java/com/techfork/global/configuration/IntegrationTestConfig.java +++ b/src/test/java/com/techfork/global/configuration/IntegrationTestConfig.java @@ -15,7 +15,7 @@ public class IntegrationTestConfig { new ElasticsearchContainer("docker.elastic.co/elasticsearch/elasticsearch:8.18.0") .withEnv("xpack.security.enabled", "false") .withEnv("discovery.type", "single-node") - .withEnv("ES_JAVA_OPTS", "-Xms256m -Xmx256m"); + .withEnv("ES_JAVA_OPTS", "-Xms2g -Xmx2g"); private static final MySQLContainer mysql = new MySQLContainer<>("mysql:8.0.36"); diff --git a/src/test/resources/application-integrationtest.yml b/src/test/resources/application-integrationtest.yml index c2e3c7c..854ecce 100644 --- a/src/test/resources/application-integrationtest.yml +++ b/src/test/resources/application-integrationtest.yml @@ -128,13 +128,13 @@ resilience4j: ratelimiter: configs: default: - limit-for-period: 37 + limit-for-period: 60 limit-refresh-period: 1m timeout-duration: 10s instances: llmSummary: base-config: default - limit-for-period: 37 + limit-for-period: 60 limit-refresh-period: 1m timeout-duration: 15s llmEmbedding: