diff --git a/.gitignore b/.gitignore index f37d2ef..24e7deb 100644 --- a/.gitignore +++ b/.gitignore @@ -53,8 +53,11 @@ stop-dev-tunnel.sh /.claude ### env ### -/.env +.env ### Apple Private Keys ### src/main/resources/keys/ -*.p8 \ No newline at end of file +*.p8 + +### Test json files ### +src/test/resources/fixtures/evaluation/ \ No newline at end of file diff --git a/src/main/java/com/techfork/domain/recommendation/config/RecommendationProperties.java b/src/main/java/com/techfork/domain/recommendation/config/RecommendationProperties.java index 1ef9a6d..f85e398 100644 --- a/src/main/java/com/techfork/domain/recommendation/config/RecommendationProperties.java +++ b/src/main/java/com/techfork/domain/recommendation/config/RecommendationProperties.java @@ -21,7 +21,7 @@ public class RecommendationProperties { private Integer mmrFinalSize = 30; - private Double lambda = 0.6; + private Double lambda = 0.95; private Integer activeUserHours = 24; @@ -34,9 +34,9 @@ public class RecommendationProperties { @NoArgsConstructor @AllArgsConstructor public static class EmbeddingWeights { - private Float title = 0.4f; - private Float summary = 0.4f; - private Float content = 0.2f; + private Float title = 0.5f; + private Float summary = 0.5f; + private Float content = 0.0f; } @Getter diff --git a/src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java b/src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java index 05ca546..6a0ed90 100644 --- a/src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java +++ b/src/main/java/com/techfork/domain/recommendation/service/LlmRecommendationService.java @@ -1,341 +1,345 @@ -package com.techfork.domain.recommendation.service; - -import co.elastic.clients.elasticsearch.ElasticsearchClient; -import co.elastic.clients.elasticsearch._types.query_dsl.Query; -import co.elastic.clients.elasticsearch.core.SearchResponse; -import 
co.elastic.clients.elasticsearch.core.search.Hit; -import com.techfork.domain.activity.repository.ReadPostRepository; -import com.techfork.global.elasticsearch.query.VectorQueryBuilder; -import com.techfork.domain.post.document.PostDocument; -import com.techfork.domain.post.entity.Post; -import com.techfork.domain.post.repository.PostRepository; -import com.techfork.domain.recommendation.config.RecommendationProperties; -import com.techfork.domain.recommendation.entity.RecommendedPost; -import com.techfork.domain.recommendation.entity.RecommendationHistory; -import com.techfork.domain.recommendation.repository.RecommendedPostRepository; -import com.techfork.domain.recommendation.repository.RecommendationHistoryRepository; -import com.techfork.domain.recommendation.service.MmrService.MmrCandidate; -import com.techfork.domain.recommendation.service.MmrService.MmrResult; -import com.techfork.domain.user.document.UserProfileDocument; -import com.techfork.domain.user.entity.User; -import com.techfork.domain.user.repository.UserProfileDocumentRepository; -import com.techfork.global.util.TimeDecayStrategy; -import com.techfork.global.util.VectorUtil; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.springframework.data.domain.PageRequest; -import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; - -import java.io.IOException; -import java.time.LocalDateTime; -import java.util.*; -import java.util.stream.Collectors; - -/** - * MMR 알고리즘 기반 추천 전략 구현 - * - Elasticsearch k-NN 검색으로 초기 후보군 수집 - * - MMR 알고리즘으로 다양성 보장 - * - 읽은 글 제외 필터링 - * - 시간 감쇠 가중치 적용 (최신 게시글 우선) - */ -@Slf4j -@Service -@Transactional -@RequiredArgsConstructor -public class LlmRecommendationService implements RecommendationService { - - private final ElasticsearchClient elasticsearchClient; - private final UserProfileDocumentRepository userProfileDocumentRepository; - private final RecommendedPostRepository 
recommendedPostRepository; - private final RecommendationHistoryRepository recommendationHistoryRepository; - private final ReadPostRepository readPostRepository; - private final PostRepository postRepository; - private final MmrService mmrService; - private final TimeDecayStrategy timeDecayStrategy; - private final RecommendationProperties properties; - private final VectorQueryBuilder vectorQueryBuilder; - - private static final String POSTS_INDEX = "posts"; - private static final String TITLE_EMBEDDING_FIELD = "titleEmbedding"; - private static final String SUMMARY_EMBEDDING_FIELD = "summaryEmbedding"; - private static final String CONTENT_CHUNKS_FIELD = "contentChunks"; - private static final String CHUNK_EMBEDDING_FIELD = "embedding"; - - @Override - public int generateRecommendationsForUser(User user) { - log.info("사용자 {} 추천 생성 시작", user.getId()); - - // 1. 사용자 프로필 벡터 조회 - Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); - if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) { - log.warn("사용자 {}의 프로필 또는 벡터를 찾을 수 없음. 추천 생성 스킵.", user.getId()); - return 0; - } - - UserProfileDocument profile = profileOpt.get(); - float[] userProfileVector = profile.getProfileVector(); - - try { - // 2. k-NN 검색으로 초기 후보군 가져오기 - List candidates = searchCandidates(userProfileVector, user); - - if (candidates.isEmpty()) { - log.info("사용자 {}의 추천 후보군을 찾을 수 없음", user.getId()); - return 0; - } - - log.info("사용자 {} 추천 후보 {} 개 발견", user.getId(), candidates.size()); - - // 3. MMR 적용하여 최종 추천 선택 - List mmrResults = mmrService.applyMmr(candidates); - - // 4. 기존 추천을 이력으로 보관 (오늘 생성된 추천 포함) - List oldRecommendations = recommendedPostRepository.findByUserOrderByRankAsc(user); - - if (!oldRecommendations.isEmpty()) { - List histories = oldRecommendations.stream() - .map(RecommendationHistory::fromRecommendedPost) - .toList(); - recommendationHistoryRepository.saveAll(histories); - recommendedPostRepository.deleteByUser(user); - } - - // 5. 
새 추천 저장 - List recommendations = new ArrayList<>(); - for (MmrResult result : mmrResults) { - Post post = postRepository.getReferenceById(result.getPostId()); - RecommendedPost recommendedPost = RecommendedPost.create( - user, - post, - result.getSimilarityScore(), - result.getMmrScore(), - result.getRank() - ); - recommendations.add(recommendedPost); - } - - recommendedPostRepository.saveAll(recommendations); - - log.info("사용자 {} 추천 생성 완료: {} 개", user.getId(), recommendations.size()); - - return recommendations.size(); - - } catch (Exception e) { - log.error("사용자 {} 추천 생성 실패", user.getId(), e); - throw new RuntimeException("추천 생성 중 오류가 발생했습니다.", e); - } - } - - /** - * 추천 생성 (평가 전용 - DB 저장 안함) - * @return 추천된 게시글 ID 리스트 - */ - public List generateRecommendationsForEvaluation(User user) { - // 1. 사용자 프로필 벡터 조회 - Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); - if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) { - log.warn("사용자 {}의 프로필 또는 벡터를 찾을 수 없음. 추천 생성 스킵.", user.getId()); - return Collections.emptyList(); - } - - float[] userProfileVector = profileOpt.get().getProfileVector(); - - try { - // 2. k-NN 검색으로 초기 후보군 가져오기 - List candidates = searchCandidates(userProfileVector, user); - - if (candidates.isEmpty()) { - log.debug("사용자 {}의 추천 후보군을 찾을 수 없음", user.getId()); - return Collections.emptyList(); - } - - // 3. MMR 적용하여 최종 추천 선택 - List mmrResults = mmrService.applyMmr(candidates); - - // 4. 추천된 게시글 ID 리스트 반환 (DB에 저장하지 않음) - return mmrResults.stream() - .map(MmrResult::getPostId) - .toList(); - - } catch (Exception e) { - log.error("사용자 {} 추천 생성 실패 (평가용)", user.getId(), e); - return Collections.emptyList(); - } - } - - /** - * 추천 생성 (평가 전용 - Train/Test Split 지원) - * 특정 읽은 글 목록(Train Set)만 제외하고 추천 생성 - * - * @param user 사용자 - * @param trainPostIds Train Set 게시글 ID 목록 (제외할 글) - * @return 추천된 게시글 ID 리스트 - */ - public List generateRecommendationsForEvaluation(User user, Set trainPostIds) { - // 1. 
사용자 프로필 벡터 조회 - Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); - if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) { - log.warn("사용자 {}의 프로필 또는 벡터를 찾을 수 없음. 추천 생성 스킵.", user.getId()); - return Collections.emptyList(); - } - - float[] userProfileVector = profileOpt.get().getProfileVector(); - - try { - // 2. k-NN 검색으로 초기 후보군 가져오기 (Train Set만 제외) - List candidates = searchCandidatesWithCustomReadHistory(userProfileVector, user, trainPostIds); - - if (candidates.isEmpty()) { - log.debug("사용자 {}의 추천 후보군을 찾을 수 없음 (Train Set {} 개 제외)", user.getId(), trainPostIds.size()); - return Collections.emptyList(); - } - - // 3. MMR 적용하여 최종 추천 선택 - List mmrResults = mmrService.applyMmr(candidates); - - // 4. 추천된 게시글 ID 리스트 반환 - return mmrResults.stream() - .map(MmrResult::getPostId) - .toList(); - - } catch (Exception e) { - log.error("사용자 {} 추천 생성 실패 (Train/Test Split 평가용)", user.getId(), e); - return Collections.emptyList(); - } - } - - /** - * Elasticsearch k-NN 검색으로 초기 후보군 조회 (커스텀 읽은 글 목록) - * Train/Test Split 평가를 위해 Train Set만 제외 - */ - private List searchCandidatesWithCustomReadHistory( - float[] userProfileVector, - User user, - Set readPostIds) throws IOException { - - log.debug("사용자 {}의 읽은 게시글 {} 개 제외 (Train Set)", user.getId(), readPostIds.size()); - - // 가중치 가져오기 - RecommendationProperties.EmbeddingWeights weights = properties.getEmbeddingWeights(); - - // 랜덤 시드 생성 (현재 시간 기반) - long randomSeed = System.currentTimeMillis(); - double randomWeight = 0.2; // 랜덤 가중치 20% - - // k-NN 쿼리 (가중 평균: title + summary + content chunks + 랜덤 요소) - Query knnQuery = vectorQueryBuilder.createWeightedVectorQueryWithRandomness( - TITLE_EMBEDDING_FIELD, - SUMMARY_EMBEDDING_FIELD, - CONTENT_CHUNKS_FIELD, - CHUNK_EMBEDDING_FIELD, - userProfileVector, - weights.getTitle(), - weights.getSummary(), - weights.getContent(), - randomSeed, - randomWeight - ); - - log.debug("ES 쿼리 실행 (Train/Test Split) - 벡터 차원: {}, 가중치 [title:{}, summary:{}, 
content:{}]", - userProfileVector.length, weights.getTitle(), weights.getSummary(), weights.getContent()); - - SearchResponse response = elasticsearchClient.search(s -> s - .index(POSTS_INDEX) - .query(knnQuery) - .size(properties.getKnnSearchSize()) - , - PostDocument.class - ); - - // 결과를 MmrCandidate로 변환 (Train Set만 필터링) - return response.hits().hits().stream() - .filter(hit -> hit.source() != null) - .filter(hit -> !readPostIds.contains(hit.source().getPostId())) - .map(this::mapToMmrCandidate) - .filter(candidate -> candidate.getSummaryVector() != null) - .toList(); - } - - /** - * Elasticsearch k-NN 검색으로 초기 후보군 조회 - * - 이미 읽은 글 제외 - * - 랜덤 시드를 사용하여 매번 다른 후보군 생성 - */ - private List searchCandidates(float[] userProfileVector, User user) throws IOException { - // 이미 읽은 글 ID 목록 - Set readPostIds = readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), PageRequest.of(0, 1000)) - .stream() - .map(readPost -> readPost.getPost().getId()) - .collect(Collectors.toSet()); - - log.debug("사용자 {}의 읽은 게시글 {} 개 제외", user.getId(), readPostIds.size()); - - // 가중치 가져오기 - RecommendationProperties.EmbeddingWeights weights = properties.getEmbeddingWeights(); - - // 랜덤 시드 생성 (현재 시간 기반) - long randomSeed = System.currentTimeMillis(); - double randomWeight = 0.2; // 랜덤 가중치 20% - - // k-NN 쿼리 (가중 평균: title + summary + content chunks + 랜덤 요소) - Query knnQuery = vectorQueryBuilder.createWeightedVectorQueryWithRandomness( - TITLE_EMBEDDING_FIELD, - SUMMARY_EMBEDDING_FIELD, - CONTENT_CHUNKS_FIELD, - CHUNK_EMBEDDING_FIELD, - userProfileVector, - weights.getTitle(), - weights.getSummary(), - weights.getContent(), - randomSeed, - randomWeight - ); - - log.debug("ES 쿼리 실행 - 벡터 차원: {}, 가중치 [title:{}, summary:{}, content:{}], 랜덤시드: {}, 랜덤가중치: {}", - userProfileVector.length, weights.getTitle(), weights.getSummary(), weights.getContent(), - randomSeed, randomWeight); - - SearchResponse response = elasticsearchClient.search(s -> s - .index(POSTS_INDEX) - .query(knnQuery) - 
.size(properties.getKnnSearchSize()) - , - PostDocument.class - ); - - // 결과를 MmrCandidate로 변환 (읽은 글만 필터링) - return response.hits().hits().stream() - .filter(hit -> hit.source() != null) - .filter(hit -> !readPostIds.contains(hit.source().getPostId())) - .map(this::mapToMmrCandidate) - .filter(candidate -> candidate.getSummaryVector() != null) - .toList(); - } - - /** - * PostDocument를 MmrCandidate로 변환 - * 시간 감쇠 가중치를 유사도 점수에 적용 - */ - private MmrCandidate mapToMmrCandidate(Hit hit) { - PostDocument doc = hit.source(); - double score = Objects.requireNonNullElse(hit.score(), 0.0); - - // 시간 감쇠 가중치 적용 - double timeDecayWeight = timeDecayStrategy.calculateWeight(Objects.requireNonNull(doc).getPublishedAt()); - double adjustedScore = score * timeDecayWeight; - - log.trace("게시글 {} 점수 조정: 원본={}, 시간가중치={}, 최종={}", - doc.getPostId(), score, timeDecayWeight, adjustedScore); - - float[] titleVector = VectorUtil.convertToFloatArray(doc.getTitleEmbedding()); - float[] summaryVector = VectorUtil.convertToFloatArray(doc.getSummaryEmbedding()); - - return MmrCandidate.builder() - .postId(doc.getPostId()) - .titleVector(titleVector) - .summaryVector(summaryVector) - .similarityScore(adjustedScore) - .build(); - } -} +package com.techfork.domain.recommendation.service; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.FieldValue; +import co.elastic.clients.elasticsearch._types.KnnSearch; +import co.elastic.clients.elasticsearch._types.query_dsl.Query; +import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.search.Hit; +import com.techfork.domain.activity.repository.ReadPostRepository; +import com.techfork.global.elasticsearch.query.VectorQueryBuilder; +import com.techfork.domain.post.document.PostDocument; +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.post.repository.PostRepository; +import 
com.techfork.domain.recommendation.config.RecommendationProperties; +import com.techfork.domain.recommendation.entity.RecommendedPost; +import com.techfork.domain.recommendation.entity.RecommendationHistory; +import com.techfork.domain.recommendation.repository.RecommendedPostRepository; +import com.techfork.domain.recommendation.repository.RecommendationHistoryRepository; +import com.techfork.domain.recommendation.service.MmrService.MmrCandidate; +import com.techfork.domain.recommendation.service.MmrService.MmrResult; +import com.techfork.domain.user.document.UserProfileDocument; +import com.techfork.domain.user.entity.User; +import com.techfork.domain.user.repository.UserProfileDocumentRepository; +import com.techfork.global.util.TimeDecayStrategy; +import com.techfork.global.util.VectorUtil; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.data.domain.PageRequest; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.io.IOException; +import java.util.*; +import java.util.stream.Collectors; + +/** + * MMR 알고리즘 기반 추천 전략 구현 + * - Elasticsearch k-NN 검색으로 초기 후보군 수집 + * - MMR 알고리즘으로 다양성 보장 + * - 읽은 글 제외 필터링 (Pre-filtering) + * - 시간 감쇠 가중치 적용 (최신 게시글 우선) + */ +@Slf4j +@Service +@Transactional +@RequiredArgsConstructor +public class LlmRecommendationService implements RecommendationService { + + private final ElasticsearchClient elasticsearchClient; + private final UserProfileDocumentRepository userProfileDocumentRepository; + private final RecommendedPostRepository recommendedPostRepository; + private final RecommendationHistoryRepository recommendationHistoryRepository; + private final ReadPostRepository readPostRepository; + private final PostRepository postRepository; + private final MmrService mmrService; + private final TimeDecayStrategy timeDecayStrategy; + private final RecommendationProperties properties; + private final 
VectorQueryBuilder vectorQueryBuilder; + + private static final String POSTS_INDEX = "posts"; + private static final String TITLE_EMBEDDING_FIELD = "titleEmbedding"; + private static final String SUMMARY_EMBEDDING_FIELD = "summaryEmbedding"; + private static final String CONTENT_CHUNKS_EMBEDDING_FIELD = "contentChunks.embedding"; + + @Override + public int generateRecommendationsForUser(User user) { + log.info("사용자 {} 추천 생성 시작", user.getId()); + + // 1. 사용자 프로필 벡터 조회 + Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); + if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) { + log.warn("사용자 {}의 프로필 또는 벡터를 찾을 수 없음. 추천 생성 스킵.", user.getId()); + return 0; + } + + UserProfileDocument profile = profileOpt.get(); + float[] userProfileVector = profile.getProfileVector(); + + try { + // 2. k-NN 검색으로 초기 후보군 가져오기 + List candidates = searchCandidates(userProfileVector, user); + + if (candidates.isEmpty()) { + log.info("사용자 {}의 추천 후보군을 찾을 수 없음", user.getId()); + return 0; + } + + log.info("사용자 {} 추천 후보 {} 개 발견", user.getId(), candidates.size()); + + // 3. MMR 적용하여 최종 추천 선택 + List mmrResults = mmrService.applyMmr(candidates); + + // 4. 기존 추천을 이력으로 보관 (오늘 생성된 추천 포함) + List oldRecommendations = recommendedPostRepository.findByUserOrderByRankAsc(user); + + if (!oldRecommendations.isEmpty()) { + List histories = oldRecommendations.stream() + .map(RecommendationHistory::fromRecommendedPost) + .toList(); + recommendationHistoryRepository.saveAll(histories); + recommendedPostRepository.deleteByUser(user); + } + + // 5. 
새 추천 저장 + List recommendations = new ArrayList<>(); + for (MmrResult result : mmrResults) { + Post post = postRepository.getReferenceById(result.getPostId()); + RecommendedPost recommendedPost = RecommendedPost.create( + user, + post, + result.getSimilarityScore(), + result.getMmrScore(), + result.getRank() + ); + recommendations.add(recommendedPost); + } + + recommendedPostRepository.saveAll(recommendations); + + log.info("사용자 {} 추천 생성 완료: {} 개", user.getId(), recommendations.size()); + + return recommendations.size(); + + } catch (Exception e) { + log.error("사용자 {} 추천 생성 실패", user.getId(), e); + throw new RuntimeException("추천 생성 중 오류가 발생했습니다.", e); + } + } + + /** + * 추천 생성 (평가 전용 - Train/Test Split 지원) + * 특정 읽은 글 목록(Train Set)만 제외하고 추천 생성 + * + * @param user 사용자 + * @param trainPostIds Train Set 게시글 ID 목록 (제외할 글) + * @return 추천된 게시글 ID 리스트 + */ + public List generateRecommendationsForEvaluation(User user, Set trainPostIds) { + // 1. 사용자 프로필 벡터 조회 + Optional profileOpt = userProfileDocumentRepository.findByUserId(user.getId()); + if (profileOpt.isEmpty() || profileOpt.get().getProfileVector() == null) { + log.warn("사용자 {}의 프로필 또는 벡터를 찾을 수 없음. 추천 생성 스킵.", user.getId()); + return Collections.emptyList(); + } + + float[] userProfileVector = profileOpt.get().getProfileVector(); + + try { + // 2. k-NN 검색으로 초기 후보군 가져오기 (Train Set만 제외) + List candidates = searchCandidatesWithCustomReadHistory(userProfileVector, user, trainPostIds); + + if (candidates.isEmpty()) { + log.debug("사용자 {}의 추천 후보군을 찾을 수 없음 (Train Set {} 개 제외)", user.getId(), trainPostIds.size()); + return Collections.emptyList(); + } + + // 3. MMR 적용하여 최종 추천 선택 + List mmrResults = mmrService.applyMmr(candidates); + + // 4. 
추천된 게시글 ID 리스트 반환 + return mmrResults.stream() + .map(MmrResult::getPostId) + .toList(); + + } catch (Exception e) { + log.error("사용자 {} 추천 생성 실패 (Train/Test Split 평가용)", user.getId(), e); + return Collections.emptyList(); + } + } + + /** + * Elasticsearch 네이티브 k-NN 검색으로 초기 후보군 조회 (커스텀 읽은 글 목록) + * Train/Test Split 평가를 위해 Train Set만 제외 + */ + private List searchCandidatesWithCustomReadHistory( + float[] userProfileVector, + User user, + Set readPostIds) throws IOException { + + log.debug("사용자 {}의 읽은 게시글 {} 개 제외 (Train Set)", user.getId(), readPostIds.size()); + + // 가중치 가져오기 + RecommendationProperties.EmbeddingWeights weights = properties.getEmbeddingWeights(); + + // 랜덤 시드 생성 (현재 시간 기반) + long randomSeed = System.currentTimeMillis(); + double randomWeight = 0.0; + + // 1. 읽은 글 제외 필터 쿼리 생성 (Pre-filtering) + Query filterQuery = createExcludeFilter(readPostIds); + + // 2. 네이티브 k-NN 검색 객체 리스트 생성 (Title + Summary + Content) + List knnSearches = vectorQueryBuilder.createKnnSearches( + TITLE_EMBEDDING_FIELD, + SUMMARY_EMBEDDING_FIELD, + CONTENT_CHUNKS_EMBEDDING_FIELD, + userProfileVector, + weights.getTitle(), + weights.getSummary(), + weights.getContent(), + properties.getKnnSearchSize(), + properties.getNumCandidates(), + filterQuery + ); + + // 3. 
랜덤 요소 추가 (function_score) + Query randomQuery = vectorQueryBuilder.createRandomScoreQuery(randomSeed, randomWeight); + + log.debug("ES k-NN 검색 실행 (Train/Test Split) - 가중치 [title:{}, summary:{}], 랜덤가중치: {}", + weights.getTitle(), weights.getSummary(), randomWeight); + + long startTime = System.currentTimeMillis(); + SearchResponse response = elasticsearchClient.search(s -> s + .index(POSTS_INDEX) + .knn(knnSearches) // k-NN 검색 (관련성 + 필터링) + .query(randomQuery) // 랜덤 점수 추가 + .size(properties.getKnnSearchSize()) + , + PostDocument.class + ); + long duration = System.currentTimeMillis() - startTime; + log.info("추천 후보군 검색 완료 (Evaluation): {} 개, 소요 시간: {}ms", response.hits().hits().size(), duration); + + // 결과를 MmrCandidate로 변환 + return response.hits().hits().stream() + .filter(hit -> hit.source() != null) + .map(this::mapToMmrCandidate) + .filter(candidate -> candidate.getSummaryVector() != null) + .toList(); + } + + /** + * Elasticsearch 네이티브 k-NN 검색으로 초기 후보군 조회 + * - 이미 읽은 글 제외 + * - 랜덤 시드를 사용하여 매번 다른 후보군 생성 + */ + private List searchCandidates(float[] userProfileVector, User user) throws IOException { + // 이미 읽은 글 ID 목록 + Set readPostIds = readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), PageRequest.of(0, 1000)) + .stream() + .map(readPost -> readPost.getPost().getId()) + .collect(Collectors.toSet()); + + log.debug("사용자 {}의 읽은 게시글 {} 개 제외", user.getId(), readPostIds.size()); + + // 가중치 가져오기 + RecommendationProperties.EmbeddingWeights weights = properties.getEmbeddingWeights(); + + // 랜덤 시드 생성 (현재 시간 기반) + long randomSeed = System.currentTimeMillis(); + double randomWeight = 0.0; // 랜덤 가중치 0% (랜덤 요소 비활성화) + + // 1. 읽은 글 제외 필터 쿼리 생성 (Pre-filtering) + Query filterQuery = createExcludeFilter(readPostIds); + + // 2. 
네이티브 k-NN 검색 객체 리스트 생성 (Title + Summary + Content) + List knnSearches = vectorQueryBuilder.createKnnSearches( + TITLE_EMBEDDING_FIELD, + SUMMARY_EMBEDDING_FIELD, + CONTENT_CHUNKS_EMBEDDING_FIELD, + userProfileVector, + weights.getTitle(), + weights.getSummary(), + weights.getContent(), + properties.getKnnSearchSize(), + properties.getNumCandidates(), + filterQuery + ); + + // 3. 랜덤 요소 추가 (function_score) + Query randomQuery = vectorQueryBuilder.createRandomScoreQuery(randomSeed, randomWeight); + + log.debug("ES k-NN 검색 실행 - 가중치 [title:{}, summary:{}], 랜덤시드: {}, 랜덤가중치: {}", + weights.getTitle(), weights.getSummary(), randomSeed, randomWeight); + + long startTime = System.currentTimeMillis(); + SearchResponse response = elasticsearchClient.search(s -> s + .index(POSTS_INDEX) + .knn(knnSearches) // k-NN 검색 (관련성 + 필터링) + .query(randomQuery) // 랜덤 점수 추가 + .size(properties.getKnnSearchSize()) + , + PostDocument.class + ); + long duration = System.currentTimeMillis() - startTime; + log.info("추천 후보군 검색 완료: {} 개, 소요 시간: {}ms", response.hits().hits().size(), duration); + + // 결과를 MmrCandidate로 변환 + return response.hits().hits().stream() + .filter(hit -> hit.source() != null) + .map(this::mapToMmrCandidate) + .filter(candidate -> candidate.getSummaryVector() != null) + .toList(); + } + + /** + * 읽은 글 제외를 위한 필터 쿼리 생성 + */ + private Query createExcludeFilter(Set readPostIds) { + if (readPostIds == null || readPostIds.isEmpty()) { + return null; + } + + List excludeValues = readPostIds.stream() + .map(FieldValue::of) + .toList(); + + return Query.of(q -> q + .bool(b -> b + .mustNot(mn -> mn + .terms(t -> t + .field("postId") + .terms(v -> v.value(excludeValues)) + ) + ) + ) + ); + } + + /** + * PostDocument를 MmrCandidate로 변환 + * 시간 감쇠 가중치를 유사도 점수에 적용 + */ + private MmrCandidate mapToMmrCandidate(Hit hit) { + PostDocument doc = hit.source(); + double score = Objects.requireNonNullElse(hit.score(), 0.0); + + // 시간 감쇠 가중치 적용 + double timeDecayWeight = 
timeDecayStrategy.calculateWeight(Objects.requireNonNull(doc).getPublishedAt()); + double adjustedScore = score * timeDecayWeight; + + log.trace("게시글 {} 점수 조정: 원본={}, 시간가중치={}, 최종={}", + doc.getPostId(), score, timeDecayWeight, adjustedScore); + + float[] titleVector = VectorUtil.convertToFloatArray(doc.getTitleEmbedding()); + float[] summaryVector = VectorUtil.convertToFloatArray(doc.getSummaryEmbedding()); + + return MmrCandidate.builder() + .postId(doc.getPostId()) + .titleVector(titleVector) + .summaryVector(summaryVector) + .similarityScore(adjustedScore) + .build(); + } +} diff --git a/src/main/java/com/techfork/domain/recommendation/service/MmrService.java b/src/main/java/com/techfork/domain/recommendation/service/MmrService.java index ae95489..af72fa6 100644 --- a/src/main/java/com/techfork/domain/recommendation/service/MmrService.java +++ b/src/main/java/com/techfork/domain/recommendation/service/MmrService.java @@ -165,11 +165,17 @@ private double calculateWeightedSimilarity(MmrCandidate candidate1, MmrCandidate summarySim = VectorUtil.cosineSimilarity(candidate1.getSummaryVector(), candidate2.getSummaryVector()); } + // 코사인 유사도(-1.0 ~ 1.0)를 0.0 ~ 1.0 범위로 정규화하여 ES 점수와 스케일을 맞춤 + double normalizedTitleSim = (titleSim + 1) / 2.0; + double normalizedSummarySim = (summarySim + 1) / 2.0; + // 가중 평균 (제목 + 요약만, 콘텐츠는 제외) double titleWeight = weights.getTitle(); double summaryWeight = weights.getSummary(); double totalWeight = titleWeight + summaryWeight; - return (titleWeight * titleSim + summaryWeight * summarySim) / totalWeight; + if (totalWeight == 0) return 0.0; + + return (titleWeight * normalizedTitleSim + summaryWeight * normalizedSummarySim) / totalWeight; } } diff --git a/src/main/java/com/techfork/domain/recommendation_quality/ImprovedRecommendationTestCase.java b/src/main/java/com/techfork/domain/recommendation_quality/ImprovedRecommendationTestCase.java deleted file mode 100644 index d66c737..0000000 --- 
a/src/main/java/com/techfork/domain/recommendation_quality/ImprovedRecommendationTestCase.java +++ /dev/null @@ -1,66 +0,0 @@ -package com.techfork.domain.recommendation_quality; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Train/Test Split 기반 개선된 추천 시스템 테스트 케이스 - * - * 기존 방식의 문제: - * 1. Ground Truth가 문자열 매칭 기반 (추천 시스템은 벡터 유사도 기반) - * 2. Recall 분모가 너무 커서(100개) 지표가 낮게 나옴 - * - * 개선 방식: - * 1. 읽은 글을 8:2로 분할 (Train/Test) - * 2. Test Set을 Ground Truth로 사용 (실제로 읽은 글 = 관심있는 글) - * 3. 적절한 Recall 분모 (Test Set 크기 = 20개 정도) - */ -@Data -@Builder -@NoArgsConstructor -@AllArgsConstructor -public class ImprovedRecommendationTestCase { - - /** - * 사용자 ID - */ - private Long userId; - - /** - * 사용자 관심사 - */ - private List interests; - - /** - * Train/Test 분할 결과 - */ - private TrainTestSplit trainTestSplit; - - /** - * Test Set을 Ground Truth로 반환 (Recall 계산용) - */ - public Set getGroundTruthPostIds() { - return trainTestSplit.getTestPostIds().stream() - .collect(Collectors.toSet()); - } - - /** - * Train Set 반환 (사용자 프로필 생성용) - */ - public List getTrainPostIds() { - return trainTestSplit.getTrainPostIds(); - } - - /** - * Test Set 반환 (평가용) - */ - public List getTestPostIds() { - return trainTestSplit.getTestPostIds(); - } -} diff --git a/src/main/java/com/techfork/domain/recommendation_quality/RecommendationTestCase.java b/src/main/java/com/techfork/domain/recommendation_quality/RecommendationTestCase.java deleted file mode 100644 index 2e0a898..0000000 --- a/src/main/java/com/techfork/domain/recommendation_quality/RecommendationTestCase.java +++ /dev/null @@ -1,53 +0,0 @@ -package com.techfork.domain.recommendation_quality; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import 
java.util.stream.Collectors; - -/** - * 추천 시스템 평가를 위한 테스트 케이스 - * 단순 구조: 사용자 + 읽은 글 + Ground Truth - */ -@Data -@Builder -@NoArgsConstructor -@AllArgsConstructor -public class RecommendationTestCase { - - /** - * 사용자 ID - */ - private Long userId; - - /** - * 사용자 관심사 - */ - private List interests; - - /** - * 읽은 글 이력 (사용자 프로필 벡터 생성에 사용) - */ - private List readPostIds; - - /** - * Ground Truth: 실제로 관심있을 만한 게시글과 관련도 - * 관련도 점수: 5(매우 관련), 4(관련), 3(보통), 2(조금), 1(약간) - */ - private Map groundTruthScores; - - /** - * Ground Truth에서 관련도 1 이상인 게시글 ID (Recall 계산용) - */ - public Set getRelevantPostIds() { - return groundTruthScores.entrySet().stream() - .filter(entry -> entry.getValue() > 0) - .map(Map.Entry::getKey) - .collect(Collectors.toSet()); - } -} diff --git a/src/main/java/com/techfork/domain/recommendation_quality/TrainTestSplit.java b/src/main/java/com/techfork/domain/recommendation_quality/TrainTestSplit.java deleted file mode 100644 index 6946985..0000000 --- a/src/main/java/com/techfork/domain/recommendation_quality/TrainTestSplit.java +++ /dev/null @@ -1,45 +0,0 @@ -package com.techfork.domain.recommendation_quality; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.util.List; - -/** - * 사용자 읽기 이력의 Train/Test 분할 결과 - * Train: 사용자 프로필 생성에 사용 - * Test: 평가 Ground Truth로 사용 - */ -@Data -@Builder -@NoArgsConstructor -@AllArgsConstructor -public class TrainTestSplit { - - /** - * Train 세트: 사용자 프로필 생성에 사용될 게시글 ID 목록 (80%) - */ - private List trainPostIds; - - /** - * Test 세트: 평가 Ground Truth로 사용될 게시글 ID 목록 (20%) - * 추천 시스템이 이 글들을 상위권에 추천했는지 평가 - */ - private List testPostIds; - - /** - * Train 세트 크기 - */ - public int getTrainSize() { - return trainPostIds.size(); - } - - /** - * Test 세트 크기 - */ - public int getTestSize() { - return testPostIds.size(); - } -} diff --git a/src/main/java/com/techfork/global/elasticsearch/query/VectorQueryBuilder.java 
b/src/main/java/com/techfork/global/elasticsearch/query/VectorQueryBuilder.java index ad81a13..c435993 100644 --- a/src/main/java/com/techfork/global/elasticsearch/query/VectorQueryBuilder.java +++ b/src/main/java/com/techfork/global/elasticsearch/query/VectorQueryBuilder.java @@ -1,93 +1,50 @@ -package com.techfork.global.elasticsearch.query; - -import co.elastic.clients.elasticsearch._types.query_dsl.ChildScoreMode; -import co.elastic.clients.elasticsearch._types.query_dsl.Query; - -/** - * Elasticsearch 벡터 검색 쿼리 빌더 인터페이스 - * script_score를 사용한 코사인 유사도 검색 쿼리 생성 - */ -public interface VectorQueryBuilder { - - /** - * 여러 벡터 필드를 가중치 합산하는 bool should 쿼리 생성 - * (title, summary, content chunks 등을 결합) - * - * @param titleField 제목 벡터 필드명 - * @param summaryField 요약 벡터 필드명 - * @param contentChunksPath content chunks nested 경로 - * @param chunkEmbeddingField chunk 임베딩 필드명 - * @param queryVector 쿼리 벡터 - * @param titleWeight 제목 가중치 - * @param summaryWeight 요약 가중치 - * @param contentWeight 컨텐츠 가중치 - * @return 가중치 적용된 복합 쿼리 - */ - Query createWeightedVectorQuery( - String titleField, - String summaryField, - String contentChunksPath, - String chunkEmbeddingField, - float[] queryVector, - float titleWeight, - float summaryWeight, - float contentWeight - ); - - /** - * 랜덤 시드를 포함한 가중치 벡터 쿼리 생성 - * (추천 재생성 시 다양성 확보용) - * - * @param titleField 제목 벡터 필드명 - * @param summaryField 요약 벡터 필드명 - * @param contentChunksPath content chunks nested 경로 - * @param chunkEmbeddingField chunk 임베딩 필드명 - * @param queryVector 쿼리 벡터 - * @param titleWeight 제목 가중치 - * @param summaryWeight 요약 가중치 - * @param contentWeight 컨텐츠 가중치 - * @param randomSeed 랜덤 시드 - * @param randomWeight 랜덤 가중치 (0.0~1.0, 보통 0.1~0.3) - * @return 랜덤 요소가 포함된 복합 쿼리 - */ - Query createWeightedVectorQueryWithRandomness( - String titleField, - String summaryField, - String contentChunksPath, - String chunkEmbeddingField, - float[] queryVector, - float titleWeight, - float summaryWeight, - float contentWeight, - long randomSeed, - double 
randomWeight - ); - - /** - * 단일 필드에 대한 script_score 쿼리 생성 - * - * @param fieldName 벡터 필드명 - * @param queryVector 쿼리 벡터 - * @param boost 부스트 가중치 - * @return script_score 쿼리 - */ - Query createScriptScoreQuery(String fieldName, float[] queryVector, float boost); - - /** - * nested 필드에 대한 script_score 쿼리 생성 - * - * @param nestedPath nested 경로 - * @param vectorFieldName 벡터 필드명 (nested 내부) - * @param queryVector 쿼리 벡터 - * @param boost 부스트 가중치 - * @param scoreMode nested 스코어 모드 (Max, Avg 등) - * @return nested script_score 쿼리 - */ - Query createNestedScriptScoreQuery( - String nestedPath, - String vectorFieldName, - float[] queryVector, - float boost, - ChildScoreMode scoreMode - ); -} +package com.techfork.global.elasticsearch.query; + +import co.elastic.clients.elasticsearch._types.KnnSearch; +import co.elastic.clients.elasticsearch._types.query_dsl.Query; +import java.util.List; + +/** + * Elasticsearch 벡터 검색 쿼리 빌더 인터페이스 + * 네이티브 k-NN 검색 및 하이브리드 검색을 위한 쿼리 생성 제공 + */ +public interface VectorQueryBuilder { + + /** + * 네이티브 k-NN 검색 객체 리스트 생성 + * (title, summary, content 필드에 대한 k-NN 검색) + * + * @param titleField 제목 벡터 필드명 + * @param summaryField 요약 벡터 필드명 + * @param contentField 컨텐츠 벡터 필드명 (Nested 경로 포함) + * @param queryVector 쿼리 벡터 + * @param titleWeight 제목 가중치 + * @param summaryWeight 요약 가중치 + * @param contentWeight 컨텐츠 가중치 + * @param k 검색할 이웃 수 + * @param numCandidates 후보군 수 + * @param filter 사전 필터링 쿼리 (null 가능) + * @return KnnSearch 객체 리스트 + */ + List createKnnSearches( + String titleField, + String summaryField, + String contentField, + float[] queryVector, + float titleWeight, + float summaryWeight, + float contentWeight, + int k, + int numCandidates, + Query filter + ); + + /** + * 랜덤 점수를 위한 function_score 쿼리 생성 + * + * @param randomSeed 랜덤 시드 + * @param randomWeight 랜덤 가중치 + * @return function_score 쿼리 + */ + Query createRandomScoreQuery(long randomSeed, double randomWeight); +} \ No newline at end of file diff --git 
a/src/main/java/com/techfork/global/elasticsearch/query/VectorSearchQueryBuilder.java b/src/main/java/com/techfork/global/elasticsearch/query/VectorSearchQueryBuilder.java index 41b0548..3a978bf 100644 --- a/src/main/java/com/techfork/global/elasticsearch/query/VectorSearchQueryBuilder.java +++ b/src/main/java/com/techfork/global/elasticsearch/query/VectorSearchQueryBuilder.java @@ -1,136 +1,101 @@ -package com.techfork.global.elasticsearch.query; - -import co.elastic.clients.elasticsearch._types.query_dsl.ChildScoreMode; -import co.elastic.clients.elasticsearch._types.query_dsl.Query; -import co.elastic.clients.json.JsonData; -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import org.springframework.stereotype.Component; - -/** - * Elasticsearch 벡터 검색 쿼리 빌더 구현체 - * script_score를 사용한 코사인 유사도 검색 쿼리 생성 - */ -@Component -@NoArgsConstructor(access = AccessLevel.PRIVATE) -public class VectorSearchQueryBuilder implements VectorQueryBuilder { - - private static final String COSINE_SIMILARITY_SCRIPT_TEMPLATE = "cosineSimilarity(params.query_vector, '%s') + 1.0"; - private static final String QUERY_VECTOR_PARAM = "query_vector"; - - @Override - public Query createWeightedVectorQuery( - String titleField, - String summaryField, - String contentChunksPath, - String chunkEmbeddingField, - float[] queryVector, - float titleWeight, - float summaryWeight, - float contentWeight - ) { - Query titleQuery = createScriptScoreQuery(titleField, queryVector, titleWeight); - Query summaryQuery = createScriptScoreQuery(summaryField, queryVector, summaryWeight); - Query chunkQuery = createNestedScriptScoreQuery( - contentChunksPath, - chunkEmbeddingField, - queryVector, - contentWeight, - ChildScoreMode.Max - ); - - return Query.of(q -> q - .bool(b -> b - .should(titleQuery) - .should(summaryQuery) - .should(chunkQuery) - ) - ); - } - - @Override - public Query createWeightedVectorQueryWithRandomness( - String titleField, - String summaryField, - String contentChunksPath, - 
String chunkEmbeddingField, - float[] queryVector, - float titleWeight, - float summaryWeight, - float contentWeight, - long randomSeed, - double randomWeight - ) { - // 기본 벡터 쿼리 생성 - Query baseQuery = createWeightedVectorQuery( - titleField, - summaryField, - contentChunksPath, - chunkEmbeddingField, - queryVector, - titleWeight, - summaryWeight, - contentWeight - ); - - // function_score로 랜덤 요소 추가 - return Query.of(q -> q - .functionScore(fs -> fs - .query(baseQuery) - .functions(fn -> fn - .randomScore(rs -> rs - .seed(String.valueOf(randomSeed)) - .field("_seq_no") - ) - .weight(randomWeight) - ) - .boostMode(co.elastic.clients.elasticsearch._types.query_dsl.FunctionBoostMode.Sum) - ) - ); - } - - @Override - public Query createScriptScoreQuery(String fieldName, float[] queryVector, float boost) { - String script = String.format(COSINE_SIMILARITY_SCRIPT_TEMPLATE, fieldName); - - return Query.of(q -> q - .scriptScore(ss -> ss - .query(mq -> mq.matchAll(m -> m)) - .script(s -> s - .source(script) - .params(QUERY_VECTOR_PARAM, JsonData.of(queryVector)) - ) - .boost(boost) - ) - ); - } - - @Override - public Query createNestedScriptScoreQuery( - String nestedPath, - String vectorFieldName, - float[] queryVector, - float boost, - ChildScoreMode scoreMode - ) { - String fullPath = nestedPath + "." 
+ vectorFieldName; - String script = String.format(COSINE_SIMILARITY_SCRIPT_TEMPLATE, fullPath); - - return Query.of(q -> q - .nested(n -> n - .path(nestedPath) - .scoreMode(ChildScoreMode.Max) - .query(nq -> nq - .scriptScore(ss -> ss - .query(mq -> mq.matchAll(m -> m)) - .script(s -> s - .source(script) - .params(QUERY_VECTOR_PARAM, JsonData.of(queryVector)) - ) - ) - ) - .boost(boost) - ) - ); - } -} +package com.techfork.global.elasticsearch.query; + +import co.elastic.clients.elasticsearch._types.KnnSearch; +import co.elastic.clients.elasticsearch._types.query_dsl.FunctionBoostMode; +import co.elastic.clients.elasticsearch._types.query_dsl.Query; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.List; + +/** + * Elasticsearch 벡터 검색 쿼리 빌더 구현체 + * 네이티브 k-NN 검색 및 하이브리드 검색을 위한 쿼리 생성 제공 + */ +@Component +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class VectorSearchQueryBuilder implements VectorQueryBuilder { + + @Override + public List createKnnSearches( + String titleField, + String summaryField, + String contentField, + float[] queryVector, + float titleWeight, + float summaryWeight, + float contentWeight, + int k, + int numCandidates, + Query filter + ) { + List knnSearches = new ArrayList<>(); + List vectorList = new ArrayList<>(); + for (float v : queryVector) { + vectorList.add(v); + } + + if (titleWeight > 0) { + knnSearches.add(KnnSearch.of(ks -> { + ks.field(titleField) + .queryVector(vectorList) + .k(k) + .numCandidates(numCandidates) + .boost(titleWeight); + if (filter != null) { + ks.filter(filter); + } + return ks; + })); + } + + if (summaryWeight > 0) { + knnSearches.add(KnnSearch.of(ks -> { + ks.field(summaryField) + .queryVector(vectorList) + .k(k) + .numCandidates(numCandidates) + .boost(summaryWeight); + if (filter != null) { + ks.filter(filter); + } + return ks; + })); + } + + if (contentWeight > 0 && contentField != null) 
{ + knnSearches.add(KnnSearch.of(ks -> { + ks.field(contentField) + .queryVector(vectorList) + .k(k) + .numCandidates(numCandidates) + .boost(contentWeight); + if (filter != null) { + ks.filter(filter); + } + return ks; + })); + } + + return knnSearches; + } + + @Override + public Query createRandomScoreQuery(long randomSeed, double randomWeight) { + return Query.of(q -> q + .functionScore(fs -> fs + .query(mq -> mq.matchAll(m -> m)) + .functions(fn -> fn + .randomScore(rs -> rs + .seed(String.valueOf(randomSeed)) + .field("_seq_no") + ) + .weight(randomWeight) + ) + .boostMode(FunctionBoostMode.Sum) + ) + ); + } +} \ No newline at end of file diff --git a/src/main/java/com/techfork/global/util/VectorUtil.java b/src/main/java/com/techfork/global/util/VectorUtil.java index 3dd10d6..42996f8 100644 --- a/src/main/java/com/techfork/global/util/VectorUtil.java +++ b/src/main/java/com/techfork/global/util/VectorUtil.java @@ -31,7 +31,7 @@ public static float[] convertToFloatArray(List embedding) { } /** - * 코사인 유사도 계산 + * 코사인 유사도 계산 (float[] vs float[]) * 두 벡터 간의 코사인 각도를 기반으로 유사도 측정 (-1.0 ~ 1.0) * * @param vectorA 첫 번째 벡터 @@ -59,4 +59,70 @@ public static double cosineSimilarity(float[] vectorA, float[] vectorB) { return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); } + + /** + * 코사인 유사도 계산 (float[] vs List) + * + * @param vector1 첫 번째 벡터 (float 배열) + * @param vector2 두 번째 벡터 (Float 리스트) + * @return 코사인 유사도 (0.0 ~ 1.0) + */ + public static double cosineSimilarity(float[] vector1, List vector2) { + if (vector1 == null || vector2 == null || vector1.length == 0 || vector2.isEmpty()) { + return 0.0; + } + + if (vector1.length != vector2.size()) { + return 0.0; + } + + double dotProduct = 0.0; + double norm1 = 0.0; + double norm2 = 0.0; + + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2.get(i); + norm1 += vector1[i] * vector1[i]; + norm2 += vector2.get(i) * vector2.get(i); + } + + if (norm1 == 0.0 || norm2 == 0.0) { + return 0.0; + } + + 
return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2)); + } + + /** + * 코사인 유사도 계산 (List vs List) + * + * @param vector1 첫 번째 벡터 (Float 리스트) + * @param vector2 두 번째 벡터 (Float 리스트) + * @return 코사인 유사도 (0.0 ~ 1.0) + */ + public static double cosineSimilarity(List vector1, List vector2) { + if (vector1 == null || vector2 == null || vector1.isEmpty() || vector2.isEmpty()) { + return 0.0; + } + + if (vector1.size() != vector2.size()) { + return 0.0; + } + + double dotProduct = 0.0; + double norm1 = 0.0; + double norm2 = 0.0; + + for (int i = 0; i < vector1.size(); i++) { + dotProduct += vector1.get(i) * vector2.get(i); + norm1 += vector1.get(i) * vector1.get(i); + norm2 += vector2.get(i) * vector2.get(i); + } + + if (norm1 == 0.0 || norm2 == 0.0) { + return 0.0; + } + + return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2)); + } } diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index aa57098..cd30af7 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -3,6 +3,8 @@ spring: name: TechFork profiles: default: local-tunnel + config: + import: optional:file:.env ai: anthropic: api-key: ${ANTHROPIC_API_KEY} @@ -95,13 +97,13 @@ recommendation: knn-search-size: 100 num-candidates: 200 mmr-final-size: 30 - lambda: 0.3 + lambda: 0.95 active-user-hours: 24 # 임베딩 가중치 설정 (합계 1.0) embedding-weights: - title: 0.2 # 제목 중요도 20% - summary: 0.2 # 요약 중요도 20% - content: 0.6 # 콘텐츠 청크 중요도 60% + title: 0.5 # 제목 중요도 50% + summary: 0.5 # 요약 중요도 50% + content: 0.0 # 콘텐츠 청크 중요도 0% (제외) # 시간 감쇠 가중치 설정 time-decay: days-7: 1.3 # 최근 7일: +30% diff --git a/src/test/java/com/techfork/domain/recommendation/LambdaOptimizationTest.java b/src/test/java/com/techfork/domain/recommendation/LambdaOptimizationTest.java deleted file mode 100644 index b76b0ea..0000000 --- a/src/test/java/com/techfork/domain/recommendation/LambdaOptimizationTest.java +++ /dev/null @@ -1,92 +0,0 @@ -package com.techfork.domain.recommendation; - -import 
com.techfork.domain.user.entity.User; -import lombok.extern.slf4j.Slf4j; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; - -@Tag("evaluation") -@Slf4j -public class LambdaOptimizationTest extends RecommendationTestBase { - - @Test - @DisplayName("Lambda 최적화 - 3가지 가중치 조합 (Train/Test Split 방식)") - void optimizeLambdaWithTrainTestSplit() { - log.info("===== Lambda 최적화 테스트 (Train/Test Split) ====="); - log.info("읽은 글 100개 → Train 80개 (프로필 생성용) + Test 20개 (평가용)"); - log.info("가중치 조합: 컨텐츠중심, 요약중심, 기본값"); - log.info("Lambda 범위: 0.0 ~ 1.0 (0.1 단위)"); - - List configs = createLambdaTestConfigs(); - List testUsers = getTestUsers(); - log.info("테스트 사용자: {} 명 (IDs: {})", testUsers.size(), TEST_USER_IDS); - - printImprovedConfigComparisonHeader(); - List results = evaluateAllConfigsWithTrainTestSplit(configs, testUsers); - printBestImprovedResultByWeightType(results); - } - - /** - * Lambda 0.0 ~ 1.0 (0.1 단위) 테스트 설정 생성 - * 컨텐츠 중심 - */ - private List createLambdaTestConfigs() { - List configs = new ArrayList<>(); - - // Lambda 0.0 ~ 1.0 (0.1 단위) - for (int i = 0; i <= 10; i++) { - double lambda = i / 10.0; - - configs.add(ConfigCombo.builder() - .name(String.format("컨텐츠중심 λ=%.1f", lambda)) - .titleWeight(0.2f) - .summaryWeight(0.2f) - .contentWeight(0.6f) - .mmrLambda(lambda) - .build()); - } - - log.info("총 {} 개 설정 생성", configs.size()); - return configs; - } - - /** - * 모든 설정 평가 (Train/Test Split) - */ - private List evaluateAllConfigsWithTrainTestSplit( - List configs, - List testUsers) { - return configs.stream() - .map(config -> { - log.debug("설정 평가 시작 (Train/Test Split): {}", config.getName()); - ImprovedEvaluationResult result = evaluateConfigWithTrainTestSplit(config, testUsers); - log.debug("설정 평가 완료 (Train/Test Split): {} - Recall={}, nDCG={}, ILD={}", - config.getName(), result.getAvgRecall(), result.getAvgNdcg(), 
result.getAvgIld()); - log.info(result.toString()); - return result; - }) - .toList(); - } - - /** - * 가중치 타입별 최고 성능 설정 출력 (Train/Test Split) - */ - private void printBestImprovedResultByWeightType(List results) { - log.info("\n===== 가중치 타입별 최고 성능 설정 (Train/Test Split) ====="); - - // 컨텐츠 중심 - ImprovedEvaluationResult bestContent = results.stream() - .filter(r -> r.getConfigName().startsWith("컨텐츠중심")) - .max(Comparator.comparingDouble(ImprovedEvaluationResult::getCompositeScore)) - .orElse(null); - if (bestContent != null) { - log.info("\n[컨텐츠 중심 최고]"); - log.info(bestContent.toString()); - } - } -} diff --git a/src/test/java/com/techfork/domain/recommendation/RecommendationConfigComparisonTest.java b/src/test/java/com/techfork/domain/recommendation/RecommendationConfigComparisonTest.java deleted file mode 100644 index 1059321..0000000 --- a/src/test/java/com/techfork/domain/recommendation/RecommendationConfigComparisonTest.java +++ /dev/null @@ -1,82 +0,0 @@ -package com.techfork.domain.recommendation; - -import com.techfork.domain.user.entity.User; -import lombok.extern.slf4j.Slf4j; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** - * 추천 시스템 설정별 성능 비교 테스트 - */ -@Tag("evaluation") -@Slf4j -public class RecommendationConfigComparisonTest extends RecommendationTestBase { - - @Test - @DisplayName("여러 설정 비교 평가 (Train/Test Split 방식)") - void compareConfigurationsWithTrainTestSplit() { - log.info("===== 설정별 성능 비교 (Train/Test Split) ====="); - log.info("읽은 글 100개 → Train 80개 (프로필 생성용) + Test 20개 (평가용)"); - - List configs = createTestConfigs(); - List testUsers = getTestUsers(); - log.info("테스트 사용자: {} 명 (IDs: {})", testUsers.size(), TEST_USER_IDS); - - printImprovedConfigComparisonHeader(); - List results = evaluateAllConfigsWithTrainTestSplit(configs, testUsers); - } - - /** - * 테스트할 설정 목록 생성 - */ - private List 
createTestConfigs() { - return Arrays.asList( - ConfigCombo.builder().name("균등 가중치") - .titleWeight(0.33f).summaryWeight(0.33f).contentWeight(0.34f).mmrLambda(0.5).build(), - - ConfigCombo.builder().name("제목 중심") - .titleWeight(0.6f).summaryWeight(0.2f).contentWeight(0.2f).mmrLambda(0.5).build(), - - ConfigCombo.builder().name("요약 중심") - .titleWeight(0.2f).summaryWeight(0.6f).contentWeight(0.2f).mmrLambda(0.5).build(), - - ConfigCombo.builder().name("컨텐츠 중심") - .titleWeight(0.2f).summaryWeight(0.2f).contentWeight(0.6f).mmrLambda(0.5).build(), - - ConfigCombo.builder().name("현재 기본값") - .titleWeight(DEFAULT_TITLE_WEIGHT).summaryWeight(DEFAULT_SUMMARY_WEIGHT) - .contentWeight(DEFAULT_CONTENT_WEIGHT).mmrLambda(DEFAULT_MMR_LAMBDA).build(), - - ConfigCombo.builder().name("제목+요약 중심") - .titleWeight(0.45f).summaryWeight(0.45f).contentWeight(0.1f).mmrLambda(0.5).build(), - - ConfigCombo.builder().name("관련성 중심 (Lambda=0)") - .titleWeight(0.33f).summaryWeight(0.33f).contentWeight(0.34f).mmrLambda(0.0).build(), - - ConfigCombo.builder().name("다양성 중심 (Lambda=1)") - .titleWeight(0.33f).summaryWeight(0.33f).contentWeight(0.34f).mmrLambda(1.0).build() - ); - } - - /** - * 모든 설정 평가 (Train/Test Split) - */ - private List evaluateAllConfigsWithTrainTestSplit( - List configs, - List testUsers) { - List results = new ArrayList<>(); - - for (ConfigCombo config : configs) { - ImprovedEvaluationResult result = evaluateConfigWithTrainTestSplit(config, testUsers); - results.add(result); - log.info(result.toString()); - } - - return results; - } -} diff --git a/src/test/java/com/techfork/domain/recommendation/RecommendationTestBase.java b/src/test/java/com/techfork/domain/recommendation/RecommendationTestBase.java deleted file mode 100644 index c90fe27..0000000 --- a/src/test/java/com/techfork/domain/recommendation/RecommendationTestBase.java +++ /dev/null @@ -1,407 +0,0 @@ -package com.techfork.domain.recommendation; - -import co.elastic.clients.elasticsearch.ElasticsearchClient; -import 
com.techfork.domain.activity.repository.ReadPostRepository; -import com.techfork.domain.post.repository.PostDocumentRepository; -import com.techfork.domain.post.repository.PostRepository; -import com.techfork.domain.recommendation.config.RecommendationProperties; -import com.techfork.domain.recommendation.repository.RecommendationHistoryRepository; -import com.techfork.domain.recommendation.repository.RecommendedPostRepository; -import com.techfork.domain.recommendation.service.LlmRecommendationService; -import com.techfork.domain.recommendation.service.MmrService; -import com.techfork.domain.recommendation_quality.ImprovedRecommendationTestCase; -import com.techfork.domain.recommendation_quality.RecommendationQualityService; -import com.techfork.domain.recommendation_quality.RecommendationTestCase; -import com.techfork.domain.user.entity.User; -import com.techfork.domain.user.enums.EInterestCategory; -import com.techfork.domain.user.repository.UserProfileDocumentRepository; -import com.techfork.global.elasticsearch.query.VectorQueryBuilder; -import com.techfork.global.util.TimeDecayStrategy; -import com.techfork.global.util.VectorUtil; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.test.context.SpringBootTest; - -import java.util.*; - -/** - * 추천 시스템 테스트를 위한 공통 베이스 클래스 - */ -@Slf4j -@SpringBootTest(properties = "spring.profiles.active=local-tunnel") -public abstract class RecommendationTestBase { - - // 테스트 상수 - protected static final int DEFAULT_K_VALUE = 10; - protected static final int DEFAULT_TEST_USER_COUNT = 5; - protected static final List TEST_USER_IDS = Arrays.asList(71L, 72L, 73L, 74L, 75L); - protected static final float DEFAULT_TITLE_WEIGHT = 0.4f; - protected static final float DEFAULT_SUMMARY_WEIGHT = 0.4f; - protected static final float DEFAULT_CONTENT_WEIGHT = 0.2f; - protected static 
final double DEFAULT_MMR_LAMBDA = 0.6; - protected static final double RECALL_WEIGHT = 0.4; - protected static final double NDCG_WEIGHT = 0.4; - protected static final double ILD_WEIGHT = 0.2; - - @Autowired - protected TestDataGenerator testDataGenerator; - @Autowired - protected RecommendationQualityService qualityService; - @Autowired - protected PostDocumentRepository postDocumentRepository; - @Autowired - protected ElasticsearchClient elasticsearchClient; - @Autowired - protected UserProfileDocumentRepository userProfileDocumentRepository; - @Autowired - protected RecommendedPostRepository recommendedPostRepository; - @Autowired - protected RecommendationHistoryRepository recommendationHistoryRepository; - @Autowired - protected ReadPostRepository readPostRepository; - @Autowired - protected PostRepository postRepository; - @Autowired - protected TimeDecayStrategy timeDecayStrategy; - @Autowired - protected VectorQueryBuilder vectorQueryBuilder; - @Autowired - protected com.techfork.domain.user.repository.UserRepository userRepository; - - protected static List cachedTestUsers; - - /** - * 설정 조합 - */ - @Data - @Builder - @AllArgsConstructor - protected static class ConfigCombo { - String name; - float titleWeight; - float summaryWeight; - float contentWeight; - double mmrLambda; - } - - /** - * 평가 결과 - */ - @Data - @Builder - @AllArgsConstructor - protected static class EvaluationResult { - String configName; - double avgRecall; - double avgNdcg; - double avgIld; - double compositeScore; - - public double getOverallScore() { - return compositeScore; - } - - @Override - public String toString() { - return String.format("%-20s | Recall: %.4f | nDCG: %.4f | ILD: %.4f | Score: %.4f", - configName, avgRecall, avgNdcg, avgIld, compositeScore); - } - } - - /** - * Train/Test Split 기반 평가 결과 - */ - @Data - @Builder - @AllArgsConstructor - protected static class ImprovedEvaluationResult { - String configName; - double avgRecall; - double avgNdcg; - double avgIld; - 
double compositeScore; - - @Override - public String toString() { - return String.format("%-20s | Recall: %.4f | nDCG: %.4f | ILD: %.4f | Score: %.4f", - configName, avgRecall, avgNdcg, avgIld, compositeScore); - } - } - - /** - * Train/Test Split 기반 사용자별 평가 메트릭 - */ - @Data - @AllArgsConstructor - protected static class ImprovedUserMetrics { - double recall; - double ndcg; - double ild; - } - - /** - * 사용자별 평가 메트릭 - */ - @Data - @AllArgsConstructor - protected static class UserMetrics { - double recall; - double ndcg; - double ild; - } - - /** - * 테스트 사용자 조회 (ID 71-75번 사용자) - */ - protected List getTestUsers() { - if (cachedTestUsers == null) { - cachedTestUsers = userRepository.findAllWithInterestCategoriesByIds(TEST_USER_IDS); - log.info("테스트 사용자 {} 명 로드 완료: IDs={}", cachedTestUsers.size(), TEST_USER_IDS); - } - return cachedTestUsers; - } - - /** - * 프로필이 있는 테스트 사용자 조회 (레거시, 호환성 유지) - */ - @Deprecated - protected List getTestUsers(int count) { - List users = getTestUsers(); - return users.subList(0, Math.min(count, users.size())); - } - - /** - * 사용자 관심사 추출 - */ - protected List getUserInterests(User user) { - if (user.getInterestCategories() == null || user.getInterestCategories().isEmpty()) { - return List.of(); - } - return user.getInterestCategories().stream() - .map(ic -> ic.getCategory()) - .toList(); - } - - /** - * 복합 점수 계산 - */ - protected double calculateCompositeScore(double recall, double ndcg, double ild) { - return recall * RECALL_WEIGHT + ndcg * NDCG_WEIGHT + ild * ILD_WEIGHT; - } - - /** - * 커스텀 RecommendationProperties 생성 - */ - protected RecommendationProperties createProperties(float titleWeight, float summaryWeight, - float contentWeight, double lambda) { - RecommendationProperties props = new RecommendationProperties(); - props.setKnnSearchSize(100); - props.setNumCandidates(200); - props.setMmrFinalSize(30); - props.setLambda(lambda); - props.setActiveUserHours(24); - - RecommendationProperties.EmbeddingWeights weights = new 
RecommendationProperties.EmbeddingWeights(); - weights.setTitle(titleWeight); - weights.setSummary(summaryWeight); - weights.setContent(contentWeight); - props.setEmbeddingWeights(weights); - - return props; - } - - /** - * 커스텀 LlmRecommendationService 생성 - */ - protected LlmRecommendationService createRecommendationService(RecommendationProperties props) { - MmrService mmrService = new MmrService(props); - return new LlmRecommendationService( - elasticsearchClient, - userProfileDocumentRepository, - recommendedPostRepository, - recommendationHistoryRepository, - readPostRepository, - postRepository, - mmrService, - timeDecayStrategy, - props, - vectorQueryBuilder - ); - } - - /** - * 게시글 ID 리스트로부터 벡터 조회 - */ - protected List getVectorsForPosts(List postIds) { - return postIds.stream() - .map(postDocumentRepository::findByPostId) - .filter(Optional::isPresent) - .map(Optional::get) - .map(doc -> VectorUtil.convertToFloatArray(doc.getSummaryEmbedding())) - .filter(Objects::nonNull) - .toList(); - } - - /** - * 평균 메트릭 계산 - */ - protected EvaluationResult calculateAverageMetrics(String configName, List metrics) { - double avgRecall = metrics.stream().mapToDouble(UserMetrics::getRecall).average().orElse(0.0); - double avgNdcg = metrics.stream().mapToDouble(UserMetrics::getNdcg).average().orElse(0.0); - double avgIld = metrics.stream().mapToDouble(UserMetrics::getIld).average().orElse(0.0); - double composite = calculateCompositeScore(avgRecall, avgNdcg, avgIld); - - log.debug("설정 평가 완료: {} - Recall={}, nDCG={}, ILD={}", configName, avgRecall, avgNdcg, avgIld); - - return EvaluationResult.builder() - .configName(configName) - .avgRecall(avgRecall) - .avgNdcg(avgNdcg) - .avgIld(avgIld) - .compositeScore(composite) - .build(); - } - - /** - * 설정 비교용 테이블 헤더 출력 - */ - protected void printConfigComparisonHeader() { - log.info("\n%-20s | %-14s | %-14s | %-14s | %-14s", - "설정", "Recall@10", "nDCG@10", "ILD", "Composite"); - log.info("-".repeat(90)); - } - - /** - * Train/Test 
Split 기반 단일 사용자 평가 - */ - protected Optional evaluateUserWithTrainTestSplit( - User user, - LlmRecommendationService service, - int k) { - try { - List interests = getUserInterests(user); - if (interests.isEmpty()) { - log.debug("사용자 {} 관심사 없음", user.getId()); - return Optional.empty(); - } - - ImprovedRecommendationTestCase testCase = testDataGenerator.generateImprovedTestCase(user, interests); - - if (testCase.getTestPostIds().isEmpty()) { - log.debug("사용자 {} Test Set 없음", user.getId()); - return Optional.empty(); - } - - // Train Set만 제외하고 추천 생성 (Test Set은 추천 후보에 포함) - Set trainPostIdsSet = new java.util.HashSet<>(testCase.getTrainPostIds()); - List recommendedIds = service.generateRecommendationsForEvaluation(user, trainPostIdsSet); - - if (recommendedIds.isEmpty()) { - log.debug("사용자 {} 추천 결과 없음", user.getId()); - return Optional.empty(); - } - - // Test Set의 각 글에 관련도 5점 부여 (실제로 읽은 글 = 매우 관련 있음) - Map relevanceScores = testCase.getTestPostIds().stream() - .collect(java.util.stream.Collectors.toMap( - postId -> postId, - postId -> 5 - )); - - double recall = qualityService.calculateRecall(recommendedIds, testCase.getGroundTruthPostIds(), k); - double ndcg = qualityService.calculateNDCG(recommendedIds, relevanceScores, k); - List vectors = getVectorsForPosts(recommendedIds.stream().limit(k).toList()); - double ild = qualityService.calculateILD(vectors); - - // 디버깅: 추천된 글 중 Test Set과 겹치는지 확인 - List topK = recommendedIds.stream().limit(k).toList(); - long matchCount = topK.stream() - .filter(testCase.getGroundTruthPostIds()::contains) - .count(); - - log.info("===== 사용자 {} 평가 상세 =====", user.getId()); - log.info("Train Set: {} 개", testCase.getTrainPostIds().size()); - log.info("Test Set: {} 개 (Ground Truth)", testCase.getTestPostIds().size()); - log.info("추천된 글: {} 개", recommendedIds.size()); - log.info("Top-{} 중 Test Set 포함: {} 개", k, matchCount); - log.info("Recall@{}: {:.4f}", k, recall); - log.info("nDCG@{}: {:.4f}", k, ndcg); - log.info("ILD: {:.4f}", ild); - 
- if (matchCount == 0) { - log.warn("⚠️ Top-{}에 Test Set이 하나도 없습니다!", k); - log.warn("Test Set ID 샘플: {}", testCase.getTestPostIds().stream().limit(5).toList()); - log.warn("추천 ID 샘플: {}", topK.stream().limit(5).toList()); - } - - return Optional.of(new ImprovedUserMetrics(recall, ndcg, ild)); - - } catch (Exception e) { - log.warn("사용자 {} 평가 중 오류", user.getId(), e); - return Optional.empty(); - } - } - - /** - * Train/Test Split 기반 설정 평가 - */ - protected ImprovedEvaluationResult evaluateConfigWithTrainTestSplit( - ConfigCombo config, - List testUsers) { - log.debug("설정 평가 시작 (Train/Test Split): {}", config.getName()); - - LlmRecommendationService service = createRecommendationService( - createProperties( - config.getTitleWeight(), - config.getSummaryWeight(), - config.getContentWeight(), - config.getMmrLambda() - ) - ); - - List metrics = testUsers.stream() - .map(user -> evaluateUserWithTrainTestSplit(user, service, DEFAULT_K_VALUE)) - .filter(Optional::isPresent) - .map(Optional::get) - .toList(); - - return calculateAverageImprovedMetrics(config.getName(), metrics); - } - - /** - * Train/Test Split 기반 평균 메트릭 계산 - */ - protected ImprovedEvaluationResult calculateAverageImprovedMetrics( - String configName, - List metrics) { - double avgRecall = metrics.stream().mapToDouble(ImprovedUserMetrics::getRecall).average().orElse(0.0); - double avgNdcg = metrics.stream().mapToDouble(ImprovedUserMetrics::getNdcg).average().orElse(0.0); - double avgIld = metrics.stream().mapToDouble(ImprovedUserMetrics::getIld).average().orElse(0.0); - - // 복합 점수: Recall 40%, nDCG 40%, ILD 20% - double composite = avgRecall * RECALL_WEIGHT + avgNdcg * NDCG_WEIGHT + avgIld * ILD_WEIGHT; - - log.debug("설정 평가 완료 (Train/Test Split): {} - Recall={}, nDCG={}, ILD={}", - configName, avgRecall, avgNdcg, avgIld); - - return ImprovedEvaluationResult.builder() - .configName(configName) - .avgRecall(avgRecall) - .avgNdcg(avgNdcg) - .avgIld(avgIld) - .compositeScore(composite) - .build(); - } - - /** - 
* Train/Test Split 기반 설정 비교 테이블 헤더 - */ - protected void printImprovedConfigComparisonHeader() { - log.info("\n%-20s | %-14s | %-14s | %-14s | %-14s", - "설정", "Recall@10", "nDCG@10", "ILD", "Composite"); - log.info("-".repeat(90)); - } -} diff --git a/src/test/java/com/techfork/domain/recommendation/RecommendationTestDataSetup.java b/src/test/java/com/techfork/domain/recommendation/RecommendationTestDataSetup.java deleted file mode 100644 index 080b6ef..0000000 --- a/src/test/java/com/techfork/domain/recommendation/RecommendationTestDataSetup.java +++ /dev/null @@ -1,43 +0,0 @@ -package com.techfork.domain.recommendation; - -import com.techfork.domain.user.entity.User; -import com.techfork.domain.user.enums.EInterestCategory; -import lombok.extern.slf4j.Slf4j; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.List; - -/** - * 추천 시스템 평가를 위한 테스트 데이터 생성 - */ -@Tag("evaluation-setup") -@Disabled("데이터 셋업용 - CI 제외") -@Slf4j -public class RecommendationTestDataSetup extends RecommendationTestBase { - - @Test - @DisplayName("테스트 데이터 생성 (5명)") - void generateTestData() { - log.info("===== 테스트 데이터 생성 ====="); - - List> interestCombos = Arrays.asList( - Arrays.asList(EInterestCategory.BACKEND), - Arrays.asList(EInterestCategory.FRONTEND), - Arrays.asList(EInterestCategory.AI_ML), - Arrays.asList(EInterestCategory.BACKEND, EInterestCategory.DATABASE), - Arrays.asList(EInterestCategory.AI_ML, EInterestCategory.DATA_SCIENCE) - ); - - for (int i = 0; i < 5; i++) { - List interests = interestCombos.get(i % interestCombos.size()); - User user = testDataGenerator.createTestUser(interests, 100); - log.info("사용자 생성 완료: ID={}, 관심사={}", user.getId(), interests); - } - - log.info("5명 사용자 생성 완료"); - } -} diff --git a/src/test/java/com/techfork/domain/recommendation/TestDataGenerator.java 
b/src/test/java/com/techfork/domain/recommendation/TestDataGenerator.java deleted file mode 100644 index 0097862..0000000 --- a/src/test/java/com/techfork/domain/recommendation/TestDataGenerator.java +++ /dev/null @@ -1,357 +0,0 @@ -package com.techfork.domain.recommendation; - -import com.techfork.domain.activity.entity.ReadPost; -import com.techfork.domain.activity.repository.ReadPostRepository; -import com.techfork.domain.post.entity.Post; -import com.techfork.domain.post.repository.PostRepository; -import com.techfork.domain.recommendation_quality.ImprovedRecommendationTestCase; -import com.techfork.domain.recommendation_quality.RecommendationTestCase; -import com.techfork.domain.recommendation_quality.TrainTestSplit; -import com.techfork.domain.user.entity.User; -import com.techfork.domain.user.entity.UserInterestCategory; -import com.techfork.domain.user.entity.UserInterestKeyword; -import com.techfork.domain.user.enums.EInterestCategory; -import com.techfork.domain.user.enums.EInterestKeyword; -import com.techfork.domain.user.enums.SocialType; -import com.techfork.domain.user.repository.UserInterestCategoryRepository; -import com.techfork.domain.user.repository.UserProfileDocumentRepository; -import com.techfork.domain.user.repository.UserRepository; -import com.techfork.domain.user.service.UserProfileService; -import jakarta.persistence.EntityManager; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Tag; -import org.springframework.data.domain.PageRequest; -import org.springframework.stereotype.Component; -import org.springframework.transaction.annotation.Transactional; - -import java.time.LocalDateTime; -import java.util.*; - -/** - * 추천 시스템 테스트를 위한 데이터 생성기 - * DB의 실제 게시글 데이터를 기반으로 테스트용 사용자 프로필과 Ground Truth 생성 - */ -@Tag("evaluation-setup") -@Disabled("데이터 셋업용 - CI 제외") -@Slf4j -@Component -@RequiredArgsConstructor -public class TestDataGenerator { - - private 
final PostRepository postRepository; - private final UserRepository userRepository; - private final UserInterestCategoryRepository userInterestCategoryRepository; - private final ReadPostRepository readPostRepository; - private final UserProfileService userProfileService; - private final UserProfileDocumentRepository userProfileDocumentRepository; - private final EntityManager entityManager; - - /** - * 테스트용 사용자 생성 및 읽은 글 이력 추가 - * - * @param interestCategories 관심사 카테고리 목록 - * @param readPostCount 읽은 글 개수 - * @return 생성된 사용자 - */ - @Transactional - public User createTestUser(List interestCategories, int readPostCount) { - // 사용자 생성 - User user = User.createSocialUser( - SocialType.KAKAO, - "testSocialId_" + UUID.randomUUID().toString(), - "test_" + System.currentTimeMillis() + "@example.com", - null - ); - user = userRepository.save(user); - - log.info("테스트 사용자 생성: ID: {}", user.getId()); - - // 관심사 카테고리 및 키워드 추가 - for (EInterestCategory category : interestCategories) { - UserInterestCategory interestCategory = UserInterestCategory.create(user, category); - userInterestCategoryRepository.save(interestCategory); - - // 해당 카테고리의 키워드 중 랜덤하게 2~4개 선택 - List availableKeywords = EInterestKeyword.getKeywordsByCategory(category); - Collections.shuffle(availableKeywords); - int keywordCount = 2 + (int) (Math.random() * 3); // 2~4개 - - for (int i = 0; i < Math.min(keywordCount, availableKeywords.size()); i++) { - UserInterestKeyword keyword = UserInterestKeyword.create( - interestCategory, - availableKeywords.get(i) - ); - interestCategory.addKeyword(keyword); - } - - userInterestCategoryRepository.save(interestCategory); - } - log.info("관심사 추가: {} (각 카테고리별 키워드 포함)", interestCategories); - - // 읽은 글 이력 생성 (관심사와 관련된 게시글 위주) - Batch Insert - List posts = findPostsRelatedToInterests(interestCategories, readPostCount); - - LocalDateTime now = LocalDateTime.now(); - int batchSize = 20; - for (int i = 0; i < Math.min(posts.size(), readPostCount); i++) { - Post post = posts.get(i); 
- ReadPost readPost = ReadPost.create( - user, - post, - now.minusDays(readPostCount - i), - 180 // 3분 읽음 - ); - entityManager.persist(readPost); - - // Batch Insert를 위한 flush & clear - if ((i + 1) % batchSize == 0) { - entityManager.flush(); - entityManager.clear(); - } - } - entityManager.flush(); - entityManager.clear(); - - log.info("읽은 글 이력 생성: {} 개 (Batch Insert)", Math.min(posts.size(), readPostCount)); - - // UserProfile 생성 (임베딩 포함) - 동기 버전 사용 - try { - userProfileService.generateUserProfileSync(user.getId()); - log.info("사용자 프로필 및 임베딩 생성 완료: userId={}", user.getId()); - } catch (Exception e) { - log.error("사용자 프로필 생성 실패: userId={}", user.getId(), e); - throw e; - } - - return user; - } - - /** - * 관심사와 관련된 게시글 찾기 (회사명 또는 제목 기반) - */ - @Transactional(readOnly = true) - public List findPostsRelatedToInterests(List interests, int limit) { - // 관심사 카테고리별 키워드 매핑 - Map> interestKeywords = new HashMap<>(); - interestKeywords.put(EInterestCategory.BACKEND, Arrays.asList("Spring", "Java", "Kotlin", "API", "서버", "Backend")); - interestKeywords.put(EInterestCategory.AI_ML, Arrays.asList("AI", "ML", "머신러닝", "딥러닝", "LLM", "GPT")); - interestKeywords.put(EInterestCategory.FRONTEND, Arrays.asList("React", "Vue", "JavaScript", "CSS", "UI", "Frontend")); - interestKeywords.put(EInterestCategory.DATA_ENGINEERING, Arrays.asList("데이터", "분석", "Spark", "Kafka")); - interestKeywords.put(EInterestCategory.DATA_SCIENCE, Arrays.asList("데이터", "분석", "ML", "통계")); - interestKeywords.put(EInterestCategory.DATABASE, Arrays.asList("SQL", "Database", "MySQL", "PostgreSQL")); - interestKeywords.put(EInterestCategory.DEVOPS, Arrays.asList("DevOps", "Docker", "Kubernetes", "CI/CD")); - interestKeywords.put(EInterestCategory.CLOUD, Arrays.asList("AWS", "클라우드", "Cloud", "Azure")); - - List allPosts = postRepository.findAll(PageRequest.of(0, 500)).getContent(); - - // 관심사 키워드와 매칭되는 게시글 찾기 - List relatedPosts = allPosts.stream() - .filter(post -> { - String title = post.getTitle().toLowerCase(); 
- return interests.stream() - .flatMap(interest -> interestKeywords.getOrDefault(interest, Collections.emptyList()).stream()) - .anyMatch(keyword -> title.contains(keyword.toLowerCase())); - }) - .limit(limit) - .toList(); - - // 매칭 안 되면 랜덤으로 선택 - if (relatedPosts.size() < limit && !allPosts.isEmpty()) { - List remaining = new ArrayList<>(allPosts); - remaining.removeAll(relatedPosts); - Collections.shuffle(remaining); - - List combined = new ArrayList<>(relatedPosts); - combined.addAll(remaining.subList(0, Math.min(limit - relatedPosts.size(), remaining.size()))); - return combined; - } - - return relatedPosts; - } - - /** - * 테스트 케이스 생성 - * 사용자의 관심사와 읽은 글을 기반으로 관련도 점수 부여 - */ - @Transactional(readOnly = true) - public RecommendationTestCase generateTestCase(User user, List interests) { - List readPostIds = readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), PageRequest.of(0, 1000)).stream() - .map(rp -> rp.getPost().getId()) - .toList(); - - // 관련 있는 게시글 찾기 및 점수 부여 - List candidatePosts = findPostsRelatedToInterests(interests, 100); - Map relevanceScores = new HashMap<>(); - - for (Post post : candidatePosts) { - // 이미 읽은 글은 제외 - if (readPostIds.contains(post.getId())) { - continue; - } - - // 관련도 점수 계산 (간단한 키워드 매칭 기반) - int score = calculateRelevanceScore(post, interests); - if (score > 0) { - relevanceScores.put(post.getId(), score); - } - } - - List interestNames = interests.stream() - .map(EInterestCategory::getDisplayName) - .toList(); - - return RecommendationTestCase.builder() - .userId(user.getId()) - .interests(interestNames) - .readPostIds(readPostIds) - .groundTruthScores(relevanceScores) - .build(); - } - - - /** - * 프로필이 있는 기존 사용자 조회 (테스트용) - * - * @param count 조회할 사용자 수 - * @return 프로필이 있는 사용자 리스트 - * @throws IllegalStateException 프로필이 있는 사용자가 충분하지 않은 경우 - */ - @Transactional(readOnly = true) - public List getUsersWithProfile(int count) { - // 1. 
모든 사용자 중 프로필이 있는 사용자 찾기 - List allUsers = userRepository.findAll(); - List usersWithProfile = allUsers.stream() - .filter(user -> userProfileDocumentRepository.findByUserId(user.getId()).isPresent()) - .filter(user -> user.getInterestCategories() != null && !user.getInterestCategories().isEmpty()) - .limit(count) - .toList(); - - // 2. 충분한 사용자가 있는지 확인 - if (usersWithProfile.size() < count) { - throw new IllegalStateException( - String.format("프로필이 있는 사용자가 부족합니다. (현재 %d명, 필요 %d명). " + - "먼저 createTestUser()로 테스트 사용자를 생성하세요.", - usersWithProfile.size(), count) - ); - } - - log.info("프로필이 있는 기존 사용자 {} 명 조회 완료", usersWithProfile.size()); - List userIds = usersWithProfile.stream().map(User::getId).toList(); - return userRepository.findAllWithInterestCategoriesByIds(userIds); - } - - /** - * 게시글의 관련도 점수 계산 (1~5점) - */ - private int calculateRelevanceScore(Post post, List interests) { - String title = post.getTitle().toLowerCase(); - - // 관심사별 키워드 매칭 개수 - int matchCount = 0; - for (EInterestCategory interest : interests) { - if (title.contains(interest.getDisplayName().toLowerCase())) { - matchCount++; - } - } - - // 매칭 개수에 따라 점수 부여 - if (matchCount >= 3) return 5; - if (matchCount == 2) return 4; - if (matchCount == 1) return 3; - - // 회사명이 유명 기업이면 기본 2점 - String company = post.getTechBlog().getCompanyName().toLowerCase(); - if (company.contains("네이버") || company.contains("kakao") || - company.contains("line") || company.contains("쿠팡")) { - return 2; - } - - return 1; // 기본 점수 - } - - /** - * 읽은 글 이력을 Train/Test로 분할 (8:2 비율) - * - * @param readPostIds 전체 읽은 글 ID 목록 - * @param trainRatio Train 세트 비율 (기본 0.8) - * @return Train/Test 분할 결과 - */ - public TrainTestSplit splitReadHistory(List readPostIds, double trainRatio) { - if (readPostIds == null || readPostIds.isEmpty()) { - return TrainTestSplit.builder() - .trainPostIds(Collections.emptyList()) - .testPostIds(Collections.emptyList()) - .build(); - } - - // 시간순으로 정렬된 리스트를 Train/Test로 분할 - int totalSize = 
readPostIds.size(); - int trainSize = (int) (totalSize * trainRatio); - - List trainIds = readPostIds.subList(0, trainSize); - List testIds = readPostIds.subList(trainSize, totalSize); - - log.info("Train/Test Split 완료: Train={}, Test={}, 비율={:.2f}", - trainIds.size(), testIds.size(), trainRatio); - - return TrainTestSplit.builder() - .trainPostIds(new ArrayList<>(trainIds)) - .testPostIds(new ArrayList<>(testIds)) - .build(); - } - - /** - * Train/Test Split 기반 개선된 테스트 케이스 생성 - * - * @param user 평가 대상 사용자 - * @param interests 사용자 관심사 - * @param trainRatio Train 세트 비율 (기본 0.8) - * @return Train/Test Split 기반 테스트 케이스 - */ - @Transactional(readOnly = true) - public ImprovedRecommendationTestCase generateImprovedTestCase( - User user, - List interests, - double trainRatio) { - - // 읽은 글 이력 조회 (시간순) - List readPostIds = readPostRepository - .findRecentReadPostsByUserIdWithMinDuration(user.getId(), PageRequest.of(0, 1000)) - .stream() - .map(rp -> rp.getPost().getId()) - .toList(); - - // Train/Test Split - TrainTestSplit split = splitReadHistory(readPostIds, trainRatio); - - List interestNames = interests.stream() - .map(EInterestCategory::getDisplayName) - .toList(); - - log.info("===== 개선된 테스트 케이스 생성 ====="); - log.info("사용자 ID: {}", user.getId()); - log.info("관심사: {}", interestNames); - log.info("전체 읽은 글: {} 개", readPostIds.size()); - log.info("Train Set: {} 개", split.getTrainSize()); - log.info("Test Set: {} 개", split.getTestSize()); - log.info("Test Set ID 샘플: {}", split.getTestPostIds().stream().limit(5).toList()); - - return ImprovedRecommendationTestCase.builder() - .userId(user.getId()) - .interests(interestNames) - .trainTestSplit(split) - .build(); - } - - /** - * Train/Test Split 기반 개선된 테스트 케이스 생성 (기본 비율 0.8) - */ - @Transactional(readOnly = true) - public ImprovedRecommendationTestCase generateImprovedTestCase( - User user, - List interests) { - return generateImprovedTestCase(user, interests, 0.8); - } -} diff --git 
a/src/test/java/com/techfork/domain/recommendation/evaluation/LambdaOptimizationTest.java b/src/test/java/com/techfork/domain/recommendation/evaluation/LambdaOptimizationTest.java new file mode 100644 index 0000000..8f096a2 --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/evaluation/LambdaOptimizationTest.java @@ -0,0 +1,86 @@ +package com.techfork.domain.recommendation.evaluation; + +import com.techfork.domain.user.entity.User; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +/** + * MMR Lambda 파라미터 최적화 테스트 + */ +@Tag("evaluation") +@Slf4j +public class LambdaOptimizationTest extends RecommendationTestBase { + + @Test + @DisplayName("Lambda 최적화 - Ground-Truth 기반 평가") + void optimizeLambdaWithGroundTruth() { + log.info("===== Lambda 최적화 테스트 (Ground-Truth 기반) ====="); + + if (cachedGroundTruth == null || cachedGroundTruth.isEmpty()) { + log.warn("Ground-Truth 데이터가 없습니다. 
Fixture 로드를 확인하세요."); + return; + } + + log.info("가중치 고정: 제목(0.5) + 요약(0.5)"); + log.info("Lambda 범위: 0.0 ~ 1.0 (0.1 단위)"); + + List configs = createLambdaTestConfigs(); + List testUsers = getTestUsers(); + log.info("테스트 사용자: {} 명", testUsers.size()); + + printLambdaOptimizationHeader(); + List results = configs.stream() + .map(config -> { + EvaluationResult result = evaluateConfigWithGroundTruthAndILD(config, testUsers); + printLambdaOptimizationResult(result); + return result; + }) + .toList(); + + printBestLambdaResults(results); + } + + private List createLambdaTestConfigs() { + List configs = new ArrayList<>(); + // Lambda 0.0 ~ 1.0 (0.1 단위) + for (int i = 0; i <= 10; i++) { + double lambda = i / 10.0; + configs.add(ConfigCombo.builder() + .name(String.format("T0.5/S0.5 λ=%.1f", lambda)) + .titleWeight(0.5f) + .summaryWeight(0.5f) + .contentWeight(0.0f) + .mmrLambda(lambda) + .build()); + } + return configs; + } + + private void printBestLambdaResults(List results) { + log.info("\n===== Lambda 최적화 결과 요약 (K=8 기준) ====="); + + // 복합 점수 최고 + results.stream() + .max(Comparator.comparingDouble(EvaluationResult::getCompositeScore)) + .ifPresent(best -> log.info(String.format("[복합 점수 최고] %s | R@8: %.4f | nDCG@8: %.4f | ILD: %.4f | Score: %.4f", + best.getConfigName(), best.getAvgRecall8(), best.getAvgNdcg8(), best.getAvgIld(), best.getCompositeScore()))); + + // 다양성(ILD) 최고 + results.stream() + .max(Comparator.comparingDouble(EvaluationResult::getAvgIld)) + .ifPresent(best -> log.info(String.format("[다양성(ILD) 최고] %s | ILD: %.4f", + best.getConfigName(), best.getAvgIld()))); + + // Recall@8 최고 + results.stream() + .max(Comparator.comparingDouble(EvaluationResult::getAvgRecall8)) + .ifPresent(best -> log.info(String.format("[Recall@8 최고] %s | R@8: %.4f", + best.getConfigName(), best.getAvgRecall8()))); + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationConfigComparisonTest.java 
b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationConfigComparisonTest.java new file mode 100644 index 0000000..6016a2d --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationConfigComparisonTest.java @@ -0,0 +1,129 @@ +package com.techfork.domain.recommendation.evaluation; + +import com.techfork.domain.user.entity.User; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +/** + * 추천 시스템 설정별 성능 비교 테스트 + */ +@Tag("evaluation") +@Slf4j +public class RecommendationConfigComparisonTest extends RecommendationTestBase { + + @Test + @DisplayName("여러 설정 비교 평가 (Ground-Truth 기반)") + void compareConfigurationsWithGroundTruth() { + log.info("===== 설정별 성능 비교 (Ground-Truth 기반) ====="); + log.info("Ground-Truth: {} 명 사용자", cachedGroundTruth.size()); + + List configs = createTestConfigs(); + List testUsers = getTestUsers(); + log.info("테스트 사용자: {} 명", testUsers.size()); + + printWeightComparisonHeader(); + List results = evaluateAllConfigsWithGroundTruth(configs, testUsers); + printBestWeightResult(results); + } + + /** + * 테스트할 설정 목록 생성 + * 가중치 조합 테스트이므로 Lambda=1.0 (관련성 100%, 다양성 제외) + */ + private List createTestConfigs() { + return Arrays.asList( + ConfigCombo.builder().name("균등 가중치") + .titleWeight(0.33f).summaryWeight(0.33f).contentWeight(0.34f).mmrLambda(1.0).build(), + + ConfigCombo.builder().name("제목 중심") + .titleWeight(0.6f).summaryWeight(0.2f).contentWeight(0.2f).mmrLambda(1.0).build(), + + ConfigCombo.builder().name("요약 중심") + .titleWeight(0.2f).summaryWeight(0.6f).contentWeight(0.2f).mmrLambda(1.0).build(), + + ConfigCombo.builder().name("컨텐츠 중심") + .titleWeight(0.2f).summaryWeight(0.2f).contentWeight(0.6f).mmrLambda(1.0).build(), + + ConfigCombo.builder().name("현재 기본값") + 
.titleWeight(DEFAULT_TITLE_WEIGHT).summaryWeight(DEFAULT_SUMMARY_WEIGHT) + .contentWeight(DEFAULT_CONTENT_WEIGHT).mmrLambda(1.0).build(), + + ConfigCombo.builder().name("제목+요약 중심") + .titleWeight(0.45f).summaryWeight(0.45f).contentWeight(0.1f).mmrLambda(1.0).build(), + + ConfigCombo.builder().name("제목+요약만 (컨텐츠 0)") + .titleWeight(0.5f).summaryWeight(0.5f).contentWeight(0.0f).mmrLambda(1.0).build() + ); + } + + /** + * 모든 설정 평가 (Ground-Truth 기반 - ILD 제외) + */ + private List evaluateAllConfigsWithGroundTruth( + List configs, + List testUsers) { + List results = new ArrayList<>(); + + for (ConfigCombo config : configs) { + EvaluationResult result = evaluateConfigWithGroundTruth(config, testUsers); + results.add(result); + printWeightComparisonResult(result); + } + + return results; + } + + /** + * 최고 성능 가중치 조합 출력 (K별 Recall, nDCG 기준) + */ + private void printBestWeightResult(List results) { + log.info("\n===== 최고 성능 가중치 조합 (K=8 첫 화면 기준) ====="); + + // Recall@8 최고 + EvaluationResult bestRecall8 = results.stream() + .max(Comparator.comparingDouble(EvaluationResult::getAvgRecall8)) + .orElse(null); + + if (bestRecall8 != null) { + log.info("\n[Recall@8 최고]"); + log.info(String.format("%-25s | R@4: %.4f | R@8: %.4f | R@30: %.4f | nDCG@4: %.4f | nDCG@8: %.4f | nDCG@30: %.4f", + bestRecall8.getConfigName(), + bestRecall8.getAvgRecall4(), bestRecall8.getAvgRecall8(), bestRecall8.getAvgRecall30(), + bestRecall8.getAvgNdcg4(), bestRecall8.getAvgNdcg8(), bestRecall8.getAvgNdcg30())); + } + + // nDCG@8 최고 + EvaluationResult bestNdcg8 = results.stream() + .max(Comparator.comparingDouble(EvaluationResult::getAvgNdcg8)) + .orElse(null); + + if (bestNdcg8 != null) { + log.info("\n[nDCG@8 최고]"); + log.info(String.format("%-25s | R@4: %.4f | R@8: %.4f | R@30: %.4f | nDCG@4: %.4f | nDCG@8: %.4f | nDCG@30: %.4f", + bestNdcg8.getConfigName(), + bestNdcg8.getAvgRecall4(), bestNdcg8.getAvgRecall8(), bestNdcg8.getAvgRecall30(), + bestNdcg8.getAvgNdcg4(), bestNdcg8.getAvgNdcg8(), 
bestNdcg8.getAvgNdcg30())); + } + + // 균형잡힌 설정 (Recall@8 + nDCG@8 평균) + EvaluationResult bestBalanced = results.stream() + .max(Comparator.comparingDouble(r -> (r.getAvgRecall8() + r.getAvgNdcg8()) / 2.0)) + .orElse(null); + + if (bestBalanced != null) { + double balancedScore = (bestBalanced.getAvgRecall8() + bestBalanced.getAvgNdcg8()) / 2.0; + log.info("\n[균형 점수 최고 (R@8 + nDCG@8 평균: {:.4f})]", balancedScore); + log.info(String.format("%-25s | R@4: %.4f | R@8: %.4f | R@30: %.4f | nDCG@4: %.4f | nDCG@8: %.4f | nDCG@30: %.4f", + bestBalanced.getConfigName(), + bestBalanced.getAvgRecall4(), bestBalanced.getAvgRecall8(), bestBalanced.getAvgRecall30(), + bestBalanced.getAvgNdcg4(), bestBalanced.getAvgNdcg8(), bestBalanced.getAvgNdcg30())); + } + } +} diff --git a/src/main/java/com/techfork/domain/recommendation_quality/RecommendationQualityService.java b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationQualityService.java similarity index 94% rename from src/main/java/com/techfork/domain/recommendation_quality/RecommendationQualityService.java rename to src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationQualityService.java index e663728..4d37947 100644 --- a/src/main/java/com/techfork/domain/recommendation_quality/RecommendationQualityService.java +++ b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationQualityService.java @@ -1,139 +1,139 @@ -package com.techfork.domain.recommendation_quality; - -import com.techfork.global.util.VectorUtil; -import org.springframework.stereotype.Service; - -import java.util.*; - -/** - * 추천 시스템 품질 평가 서비스 - * - Recall@K: 정답 아이템 재현율 - * - nDCG@K: 순위 기반 정확도 - * - ILD (Intra-List Diversity): 추천 목록 내 다양성 - */ -@Service -public class RecommendationQualityService { - - /** - * Recall@K 계산 - * 수식: (Top K 결과 중 정답 문서 개수) / (전체 정답 문서 개수) - * - * @param recommendedIds 추천된 게시글 ID 리스트 (순서대로) - * @param relevantIds 실제 관련있는(정답) 게시글 ID 집합 - * @param k 평가할 상위 개수 (예: 10) - */ - 
public double calculateRecall(List recommendedIds, Set relevantIds, int k) { - if (relevantIds == null || relevantIds.isEmpty()) { - return 0.0; - } - - List topKResults = recommendedIds.stream() - .limit(k) - .toList(); - - long hitCount = topKResults.stream() - .filter(relevantIds::contains) - .count(); - - return (double) hitCount / relevantIds.size(); - } - - /** - * nDCG@K 계산 - * 수식: DCG@K / IDCG@K - * DCG는 상위 결과일수록 가중치를 높게 부여 - * - * @param recommendedIds 추천된 게시글 ID 리스트 - * @param relevanceScores 게시글 ID별 관련도 점수 (Map) - * 점수가 높을수록 더 관련있음 (예: 1~5점) - * @param k 평가할 상위 개수 (예: 10) - */ - public double calculateNDCG(List recommendedIds, Map relevanceScores, int k) { - double dcg = calculateDCG(recommendedIds, relevanceScores, k); - double idcg = calculateIDCG(relevanceScores, k); - - if (idcg == 0.0) return 0.0; - - return dcg / idcg; - } - - /** - * ILD (Intra-List Diversity) 계산 - * 추천 목록 내 아이템 간 평균 비유사도 (1 - 평균 코사인 유사도) - * 값이 클수록 다양한 추천 - * - * @param itemVectors 추천된 아이템들의 벡터 리스트 (순서대로) - * @return 0.0 ~ 1.0 사이 값 (1.0에 가까울수록 다양함) - */ - public double calculateILD(List itemVectors) { - if (itemVectors == null || itemVectors.size() <= 1) { - return 0.0; - } - - List validVectors = itemVectors.stream() - .filter(Objects::nonNull) - .filter(v -> v.length > 0) - .toList(); - - if (validVectors.size() <= 1) { - return 0.0; - } - - double totalSimilarity = 0.0; - int pairCount = 0; - - // 모든 아이템 쌍의 유사도 계산 - for (int i = 0; i < validVectors.size(); i++) { - for (int j = i + 1; j < validVectors.size(); j++) { - double similarity = VectorUtil.cosineSimilarity(validVectors.get(i), validVectors.get(j)); - totalSimilarity += similarity; - pairCount++; - } - } - - if (pairCount == 0) { - return 0.0; - } - - double avgSimilarity = totalSimilarity / pairCount; - return 1.0 - avgSimilarity; // 비유사도 = 1 - 유사도 - } - - /** - * DCG (Discounted Cumulative Gain) 계산 - */ - private double calculateDCG(List recommendedIds, Map relevanceScores, int k) { - double dcg = 0.0; - - for (int i 
= 0; i < Math.min(recommendedIds.size(), k); i++) { - Long postId = recommendedIds.get(i); - int relevance = relevanceScores.getOrDefault(postId, 0); - - if (relevance > 0) { - // DCG 공식: rel_i / log2(i + 2) - dcg += relevance / (Math.log(i + 2) / Math.log(2)); - } - } - return dcg; - } - - /** - * IDCG (Ideal DCG) 계산 - * 이상적인 순서(관련도가 높은 순)로 정렬했을 때의 DCG - */ - private double calculateIDCG(Map relevanceScores, int k) { - List idealRelevances = relevanceScores.values().stream() - .sorted(Comparator.reverseOrder()) - .limit(k) - .toList(); - - double idcg = 0.0; - for (int i = 0; i < idealRelevances.size(); i++) { - int relevance = idealRelevances.get(i); - if (relevance > 0) { - idcg += relevance / (Math.log(i + 2) / Math.log(2)); - } - } - return idcg; - } -} +package com.techfork.domain.recommendation.evaluation; + +import com.techfork.global.util.VectorUtil; +import org.springframework.stereotype.Component; + +import java.util.*; + +/** + * 추천 시스템 품질 평가 서비스 + * - Recall@K: 정답 아이템 재현율 + * - nDCG@K: 순위 기반 정확도 + * - ILD (Intra-List Diversity): 추천 목록 내 다양성 + */ +@Component +public class RecommendationQualityService { + + /** + * Recall@K 계산 + * 수식: (Top K 결과 중 정답 문서 개수) / (전체 정답 문서 개수) + * + * @param recommendedIds 추천된 게시글 ID 리스트 (순서대로) + * @param relevantIds 실제 관련있는(정답) 게시글 ID 집합 + * @param k 평가할 상위 개수 (예: 10) + */ + public double calculateRecall(List recommendedIds, Set relevantIds, int k) { + if (relevantIds == null || relevantIds.isEmpty()) { + return 0.0; + } + + List topKResults = recommendedIds.stream() + .limit(k) + .toList(); + + long hitCount = topKResults.stream() + .filter(relevantIds::contains) + .count(); + + return (double) hitCount / relevantIds.size(); + } + + /** + * nDCG@K 계산 + * 수식: DCG@K / IDCG@K + * DCG는 상위 결과일수록 가중치를 높게 부여 + * + * @param recommendedIds 추천된 게시글 ID 리스트 + * @param relevanceScores 게시글 ID별 관련도 점수 (Map) + * 점수가 높을수록 더 관련있음 (예: 1~5점) + * @param k 평가할 상위 개수 (예: 10) + */ + public double calculateNDCG(List recommendedIds, Map 
relevanceScores, int k) { + double dcg = calculateDCG(recommendedIds, relevanceScores, k); + double idcg = calculateIDCG(relevanceScores, k); + + if (idcg == 0.0) return 0.0; + + return dcg / idcg; + } + + /** + * ILD (Intra-List Diversity) 계산 + * 추천 목록 내 아이템 간 평균 비유사도 (1 - 평균 코사인 유사도) + * 값이 클수록 다양한 추천 + * + * @param itemVectors 추천된 아이템들의 벡터 리스트 (순서대로) + * @return 0.0 ~ 1.0 사이 값 (1.0에 가까울수록 다양함) + */ + public double calculateILD(List itemVectors) { + if (itemVectors == null || itemVectors.size() <= 1) { + return 0.0; + } + + List validVectors = itemVectors.stream() + .filter(Objects::nonNull) + .filter(v -> v.length > 0) + .toList(); + + if (validVectors.size() <= 1) { + return 0.0; + } + + double totalSimilarity = 0.0; + int pairCount = 0; + + // 모든 아이템 쌍의 유사도 계산 + for (int i = 0; i < validVectors.size(); i++) { + for (int j = i + 1; j < validVectors.size(); j++) { + double similarity = VectorUtil.cosineSimilarity(validVectors.get(i), validVectors.get(j)); + totalSimilarity += similarity; + pairCount++; + } + } + + if (pairCount == 0) { + return 0.0; + } + + double avgSimilarity = totalSimilarity / pairCount; + return 1.0 - avgSimilarity; // 비유사도 = 1 - 유사도 + } + + /** + * DCG (Discounted Cumulative Gain) 계산 + */ + private double calculateDCG(List recommendedIds, Map relevanceScores, int k) { + double dcg = 0.0; + + for (int i = 0; i < Math.min(recommendedIds.size(), k); i++) { + Long postId = recommendedIds.get(i); + int relevance = relevanceScores.getOrDefault(postId, 0); + + if (relevance > 0) { + // DCG 공식: rel_i / log2(i + 2) + dcg += relevance / (Math.log(i + 2) / Math.log(2)); + } + } + return dcg; + } + + /** + * IDCG (Ideal DCG) 계산 + * 이상적인 순서(관련도가 높은 순)로 정렬했을 때의 DCG + */ + private double calculateIDCG(Map relevanceScores, int k) { + List idealRelevances = relevanceScores.values().stream() + .sorted(Comparator.reverseOrder()) + .limit(k) + .toList(); + + double idcg = 0.0; + for (int i = 0; i < idealRelevances.size(); i++) { + int relevance = 
idealRelevances.get(i); + if (relevance > 0) { + idcg += relevance / (Math.log(i + 2) / Math.log(2)); + } + } + return idcg; + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationTestBase.java b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationTestBase.java new file mode 100644 index 0000000..2d9b0ed --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/evaluation/RecommendationTestBase.java @@ -0,0 +1,271 @@ +package com.techfork.domain.recommendation.evaluation; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import com.techfork.domain.activity.repository.ReadPostRepository; +import com.techfork.domain.post.repository.PostDocumentRepository; +import com.techfork.domain.post.repository.PostRepository; +import com.techfork.domain.recommendation.config.RecommendationProperties; +import com.techfork.domain.recommendation.repository.RecommendationHistoryRepository; +import com.techfork.domain.recommendation.repository.RecommendedPostRepository; +import com.techfork.domain.recommendation.service.LlmRecommendationService; +import com.techfork.domain.recommendation.service.MmrService; +import com.techfork.domain.recommendation.util.EvaluationFixtureLoader; +import com.techfork.domain.user.entity.User; +import com.techfork.domain.user.enums.EInterestCategory; +import com.techfork.domain.user.repository.UserProfileDocumentRepository; +import com.techfork.global.common.IntegrationTestBase; +import com.techfork.global.elasticsearch.query.VectorQueryBuilder; +import com.techfork.global.util.TimeDecayStrategy; +import com.techfork.global.util.VectorUtil; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestInstance; +import org.springframework.beans.factory.annotation.Autowired; + +import java.io.IOException; +import java.util.*; + +/** + * 추천 시스템 
테스트를 위한 공통 베이스 클래스 + */ +@Slf4j +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public abstract class RecommendationTestBase extends IntegrationTestBase { + + // 테스트 상수 + protected static final int K_FIRST_ROW = 4; // 첫 줄 + protected static final int K_FIRST_SCREEN = 8; // 첫 화면 + protected static final int K_DEEP_EXPLORE = 30; // 깊은 탐색 + + protected static final float DEFAULT_TITLE_WEIGHT = 0.4f; + protected static final float DEFAULT_SUMMARY_WEIGHT = 0.4f; + protected static final float DEFAULT_CONTENT_WEIGHT = 0.2f; + + protected static final double RECALL_WEIGHT = 0.4; + protected static final double NDCG_WEIGHT = 0.4; + protected static final double ILD_WEIGHT = 0.2; + + @Autowired protected EvaluationFixtureLoader fixtureLoader; + @Autowired protected RecommendationQualityService qualityService; + @Autowired protected PostDocumentRepository postDocumentRepository; + @Autowired protected ElasticsearchClient elasticsearchClient; + @Autowired protected UserProfileDocumentRepository userProfileDocumentRepository; + @Autowired protected RecommendedPostRepository recommendedPostRepository; + @Autowired protected RecommendationHistoryRepository recommendationHistoryRepository; + @Autowired protected ReadPostRepository readPostRepository; + @Autowired protected PostRepository postRepository; + @Autowired protected TimeDecayStrategy timeDecayStrategy; + @Autowired protected VectorQueryBuilder vectorQueryBuilder; + @Autowired protected com.techfork.domain.user.repository.UserRepository userRepository; + + protected static List cachedTestUsers; + protected static Map> cachedGroundTruth; + private static boolean fixturesLoaded = false; + + @BeforeAll + void loadFixtures() { + if (!fixturesLoaded) { + log.info("===== Fixture 데이터 로드 시작 ====="); + cachedGroundTruth = fixtureLoader.loadAll(); + fixturesLoaded = true; + log.info("===== Fixture 데이터 로드 완료: {} 명 =====", cachedGroundTruth.size()); + } + } + + @Data + @Builder + @AllArgsConstructor + protected static class 
ConfigCombo { + String name; + float titleWeight; + float summaryWeight; + float contentWeight; + double mmrLambda; + } + + @Data + @Builder + @AllArgsConstructor + protected static class EvaluationResult { + String configName; + double avgRecall4; double avgNdcg4; + double avgRecall8; double avgNdcg8; + double avgRecall30; double avgNdcg30; + double avgIld; + double compositeScore; + } + + @Data + @AllArgsConstructor + protected static class UserMetrics { + double recall4; double ndcg4; + double recall8; double ndcg8; + double recall30; double ndcg30; + double ild; + } + + protected List getTestUsers() { + if (cachedTestUsers == null) { + List testUserIds = new ArrayList<>(cachedGroundTruth.keySet()); + cachedTestUsers = userRepository.findAllWithInterestCategoriesByIds(testUserIds); + } + return cachedTestUsers; + } + + protected double calculateCompositeScore(double recall, double ndcg, double ild) { + return recall * RECALL_WEIGHT + ndcg * NDCG_WEIGHT + ild * ILD_WEIGHT; + } + + protected RecommendationProperties createProperties(float tw, float sw, float cw, double lambda) { + RecommendationProperties props = new RecommendationProperties(); + props.setKnnSearchSize(100); + props.setNumCandidates(200); + props.setMmrFinalSize(30); + props.setLambda(lambda); + props.setActiveUserHours(24); + RecommendationProperties.EmbeddingWeights weights = new RecommendationProperties.EmbeddingWeights(); + weights.setTitle(tw); weights.setSummary(sw); weights.setContent(cw); + props.setEmbeddingWeights(weights); + return props; + } + + protected LlmRecommendationService createRecommendationService(RecommendationProperties props) { + MmrService mmrService = new MmrService(props); + return new LlmRecommendationService( + elasticsearchClient, userProfileDocumentRepository, recommendedPostRepository, + recommendationHistoryRepository, readPostRepository, postRepository, + mmrService, timeDecayStrategy, props, vectorQueryBuilder + ); + } + + protected EvaluationResult 
calculateAverageMetrics(String configName, List metrics) { + double r4 = metrics.stream().mapToDouble(UserMetrics::getRecall4).average().orElse(0.0); + double n4 = metrics.stream().mapToDouble(UserMetrics::getNdcg4).average().orElse(0.0); + double r8 = metrics.stream().mapToDouble(UserMetrics::getRecall8).average().orElse(0.0); + double n8 = metrics.stream().mapToDouble(UserMetrics::getNdcg8).average().orElse(0.0); + double r30 = metrics.stream().mapToDouble(UserMetrics::getRecall30).average().orElse(0.0); + double n30 = metrics.stream().mapToDouble(UserMetrics::getNdcg30).average().orElse(0.0); + double ild = metrics.stream().mapToDouble(UserMetrics::getIld).average().orElse(0.0); + double score = calculateCompositeScore(r8, n8, ild); + + return EvaluationResult.builder() + .configName(configName) + .avgRecall4(r4).avgNdcg4(n4) + .avgRecall8(r8).avgNdcg8(n8) + .avgRecall30(r30).avgNdcg30(n30) + .avgIld(ild).compositeScore(score) + .build(); + } + + /** + * ILD 없이 일반 평가 (가중치 비교용) + */ + protected Optional evaluateUserWithGroundTruth(User user, LlmRecommendationService service) { + try { + Map groundTruth = cachedGroundTruth.get(user.getId()); + if (groundTruth == null || groundTruth.isEmpty()) return Optional.empty(); + + Set readIds = readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), org.springframework.data.domain.PageRequest.of(0, 10000)) + .stream().map(rp -> rp.getPost().getId()).collect(java.util.stream.Collectors.toSet()); + + List recIds = service.generateRecommendationsForEvaluation(user, readIds); + if (recIds.isEmpty()) return Optional.empty(); + + double r4 = qualityService.calculateRecall(recIds, groundTruth.keySet(), K_FIRST_ROW); + double n4 = qualityService.calculateNDCG(recIds, groundTruth, K_FIRST_ROW); + double r8 = qualityService.calculateRecall(recIds, groundTruth.keySet(), K_FIRST_SCREEN); + double n8 = qualityService.calculateNDCG(recIds, groundTruth, K_FIRST_SCREEN); + double r30 = 
qualityService.calculateRecall(recIds, groundTruth.keySet(), K_DEEP_EXPLORE); + double n30 = qualityService.calculateNDCG(recIds, groundTruth, K_DEEP_EXPLORE); + + return Optional.of(new UserMetrics(r4, n4, r8, n8, r30, n30, 0.0)); + } catch (Exception e) { + return Optional.empty(); + } + } + + /** + * ILD 포함 평가 (Lambda 최적화용) + */ + protected Optional evaluateUserWithGroundTruthAndILD(User user, LlmRecommendationService service) { + try { + Map groundTruth = cachedGroundTruth.get(user.getId()); + if (groundTruth == null || groundTruth.isEmpty()) return Optional.empty(); + + Set readIds = readPostRepository.findRecentReadPostsByUserIdWithMinDuration(user.getId(), org.springframework.data.domain.PageRequest.of(0, 10000)) + .stream().map(rp -> rp.getPost().getId()).collect(java.util.stream.Collectors.toSet()); + + List recIds = service.generateRecommendationsForEvaluation(user, readIds); + if (recIds.isEmpty()) return Optional.empty(); + + double r4 = qualityService.calculateRecall(recIds, groundTruth.keySet(), K_FIRST_ROW); + double n4 = qualityService.calculateNDCG(recIds, groundTruth, K_FIRST_ROW); + double r8 = qualityService.calculateRecall(recIds, groundTruth.keySet(), K_FIRST_SCREEN); + double n8 = qualityService.calculateNDCG(recIds, groundTruth, K_FIRST_SCREEN); + double r30 = qualityService.calculateRecall(recIds, groundTruth.keySet(), K_DEEP_EXPLORE); + double n30 = qualityService.calculateNDCG(recIds, groundTruth, K_DEEP_EXPLORE); + + List vectors = recIds.stream().limit(K_FIRST_SCREEN) + .map(id -> postDocumentRepository.findByPostId(id).map(d -> VectorUtil.convertToFloatArray(d.getSummaryEmbedding())).orElse(null)) + .filter(Objects::nonNull).toList(); + double ild = qualityService.calculateILD(vectors); + + return Optional.of(new UserMetrics(r4, n4, r8, n8, r30, n30, ild)); + } catch (Exception e) { + log.warn("사용자 {} 평가 중 오류: {}", user.getId(), e.getMessage()); + return Optional.empty(); + } + } + + protected EvaluationResult 
evaluateConfigWithGroundTruth(ConfigCombo config, List testUsers) { + LlmRecommendationService service = createRecommendationService( + createProperties(config.getTitleWeight(), config.getSummaryWeight(), config.getContentWeight(), config.getMmrLambda())); + + List metrics = testUsers.stream() + .map(user -> evaluateUserWithGroundTruth(user, service)) + .filter(Optional::isPresent) + .map(Optional::get) + .toList(); + + return calculateAverageMetrics(config.getName(), metrics); + } + + protected EvaluationResult evaluateConfigWithGroundTruthAndILD(ConfigCombo config, List testUsers) { + LlmRecommendationService service = createRecommendationService( + createProperties(config.getTitleWeight(), config.getSummaryWeight(), config.getContentWeight(), config.getMmrLambda())); + + List metrics = testUsers.stream() + .map(user -> evaluateUserWithGroundTruthAndILD(user, service)) + .filter(Optional::isPresent) + .map(Optional::get) + .toList(); + + return calculateAverageMetrics(config.getName(), metrics); + } + + protected void printWeightComparisonHeader() { + log.info(String.format("\n%-25s | %-10s | %-10s | %-10s | %-10s | %-10s | %-10s", + "설정", "R@4", "R@8", "R@30", "nDCG@4", "nDCG@8", "nDCG@30")); + log.info("-".repeat(100)); + } + + protected void printWeightComparisonResult(EvaluationResult result) { + log.info(String.format("%-25s | %.4f | %.4f | %.4f | %.4f | %.4f | %.4f", + result.getConfigName(), + result.getAvgRecall4(), result.getAvgRecall8(), result.getAvgRecall30(), + result.getAvgNdcg4(), result.getAvgNdcg8(), result.getAvgNdcg30())); + } + + protected void printLambdaOptimizationHeader() { + log.info(String.format("\n%-25s | %-10s | %-10s | %-10s | %-10s", "설정", "R@8", "nDCG@8", "ILD", "Composite")); + log.info("-".repeat(75)); + } + + protected void printLambdaOptimizationResult(EvaluationResult result) { + log.info(String.format("%-25s | %.4f | %.4f | %.4f | %.4f", + result.getConfigName(), result.getAvgRecall8(), result.getAvgNdcg8(), 
result.getAvgIld(), result.getCompositeScore())); + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/evaluation/TitleSummaryRatioOptimizationTest.java b/src/test/java/com/techfork/domain/recommendation/evaluation/TitleSummaryRatioOptimizationTest.java new file mode 100644 index 0000000..59fcc43 --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/evaluation/TitleSummaryRatioOptimizationTest.java @@ -0,0 +1,96 @@ +package com.techfork.domain.recommendation.evaluation; + +import com.techfork.domain.user.entity.User; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +/** + * 제목(Title)과 요약(Summary)의 가중치 비율을 최적화하는 테스트 + * 본문(Content) 가중치는 0으로 고정하고 테스트합니다. + */ +@Tag("evaluation") +@Slf4j +public class TitleSummaryRatioOptimizationTest extends RecommendationTestBase { + + @Test + @DisplayName("제목 vs 요약 가중치 최적화 (본문 제외)") + void optimizeTitleSummaryRatio() { + log.info("===== 제목 vs 요약 가중치 최적화 테스트 ====="); + log.info("조건: 본문(Content) 가중치 = 0.0, Lambda = 1.0 (순수 관련성)"); + log.info("범위: 제목 0.0~1.0, 요약 1.0~0.0 (0.1 단위)"); + + List configs = createRatioTestConfigs(); + List testUsers = getTestUsers(); + log.info("테스트 사용자: {} 명", testUsers.size()); + + printWeightComparisonHeader(); + List results = new ArrayList<>(); + + for (ConfigCombo config : configs) { + EvaluationResult result = evaluateConfigWithGroundTruth(config, testUsers); + results.add(result); + printWeightComparisonResult(result); + } + + printBestRatioResult(results); + } + + /** + * 제목 0.0~1.0 (0.1 단위) 비율 설정 생성 + */ + private List createRatioTestConfigs() { + List configs = new ArrayList<>(); + + for (int i = 0; i <= 10; i++) { + float titleWeight = i / 10.0f; + float summaryWeight = 1.0f - titleWeight; + + configs.add(ConfigCombo.builder() + .name(String.format("T:%.1f / S:%.1f", titleWeight, 
summaryWeight)) + .titleWeight(titleWeight) + .summaryWeight(summaryWeight) + .contentWeight(0.0f) + .mmrLambda(1.0) + .build()); + } + + return configs; + } + + /** + * 최적 비율 결과 출력 + */ + private void printBestRatioResult(List results) { + log.info(" ===== 제목 vs 요약 최적 비율 결과 (K=8 기준) ====="); + + // nDCG@8 최고 기준 + EvaluationResult bestNdcg8 = results.stream() + .max(Comparator.comparingDouble(EvaluationResult::getAvgNdcg8)) + .orElse(null); + + if (bestNdcg8 != null) { + log.info(" [nDCG@8 최고]"); + log.info(String.format("최적 설정: %s", bestNdcg8.getConfigName())); + log.info(String.format("성능: R@8: %.4f, nDCG@8: %.4f", + bestNdcg8.getAvgRecall8(), bestNdcg8.getAvgNdcg8())); + } + + // Recall@8 최고 기준 + EvaluationResult bestRecall8 = results.stream() + .max(Comparator.comparingDouble(EvaluationResult::getAvgRecall8)) + .orElse(null); + + if (bestRecall8 != null && (bestNdcg8 == null || !bestRecall8.getConfigName().equals(bestNdcg8.getConfigName()))) { + log.info(" [Recall@8 최고]"); + log.info(String.format("최적 설정: %s", bestRecall8.getConfigName())); + log.info(String.format("성능: R@8: %.4f, nDCG@8: %.4f", + bestRecall8.getAvgRecall8(), bestRecall8.getAvgNdcg8())); + } + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/setup/PostDataExporter.java b/src/test/java/com/techfork/domain/recommendation/setup/PostDataExporter.java new file mode 100644 index 0000000..178991b --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/setup/PostDataExporter.java @@ -0,0 +1,120 @@ +package com.techfork.domain.recommendation.setup; + +import com.techfork.domain.post.document.PostDocument; +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.post.repository.PostDocumentRepository; +import com.techfork.domain.post.repository.PostRepository; +import com.techfork.domain.recommendation.setup.components.FileExporter; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; +import 
org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Sort; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.transaction.annotation.Transactional; + +import java.io.IOException; +import java.util.*; +import java.util.stream.Collectors; + +/** + * 원격 DB에서 게시글 데이터만 JSON으로 export + */ +@Tag("evaluation-setup") +@Disabled("수동 실행용 - CI 제외") +@Slf4j +@SpringBootTest +@ActiveProfiles("local-tunnel") +class PostDataExporter { + + @Autowired + private PostRepository postRepository; + + @Autowired + private PostDocumentRepository postDocumentRepository; + + @Autowired + private FileExporter fileExporter; + + private static final int POST_EXPORT_COUNT = 1200; // 전체 데이터 사용 (약 1100개) + + @Test + @DisplayName("원격 DB에서 게시글 데이터 Export") + @Transactional + public void exportPostData() throws IOException { + log.info("===== 게시글 데이터 Export 시작 ====="); + log.info("원격 DB: MySQL + Elasticsearch"); + + fileExporter.ensureOutputDirectory(); + + List posts = exportPosts(); + log.info("✓ 게시글 {} 개 export 완료", posts.size()); + + Set postIds = posts.stream() + .map(Post::getId) + .collect(Collectors.toSet()); + List postDocuments = exportPostDocuments(postIds); + log.info("✓ PostDocument {} 개 export 완료 (임베딩 포함)", postDocuments.size()); + + log.info("===== 게시글 데이터 Export 완료 ====="); + log.info("출력 위치: {}", fileExporter.getOutputDir()); + log.info("\n생성된 파일:"); + log.info(" - posts.json ({} 개)", posts.size()); + log.info(" - post-documents.json ({} 개, titleEmbedding + summaryEmbedding 포함)", postDocuments.size()); + } + + private List exportPosts() throws IOException { + List posts = postRepository.findAll( + PageRequest.of(0, POST_EXPORT_COUNT, Sort.by("publishedAt").descending()) + ).getContent(); + + List> postDtos = posts.stream() + 
.map(this::convertPostToDto) + .toList(); + + fileExporter.writeJsonFile("posts.json", postDtos); + return posts; + } + + private List exportPostDocuments(Set postIds) throws IOException { + List documents = new ArrayList<>(); + + for (Long postId : postIds) { + Optional docOpt = postDocumentRepository.findByPostId(postId); + if (docOpt.isPresent()) { + documents.add(docOpt.get()); + } else { + log.warn("PostDocument not found for postId: {}", postId); + } + } + + fileExporter.writeJsonFile("post-documents.json", documents); + + return documents; + } + + private Map convertPostToDto(Post post) { + Map dto = new HashMap<>(); + dto.put("id", post.getId()); + dto.put("title", post.getTitle()); + dto.put("url", post.getUrl()); + dto.put("summary", post.getSummary()); + dto.put("shortSummary", post.getShortSummary()); + dto.put("company", post.getCompany()); + dto.put("logoUrl", post.getLogoUrl()); + dto.put("thumbnailUrl", post.getThumbnailUrl()); + dto.put("publishedAt", post.getPublishedAt() != null ? 
post.getPublishedAt().toString() : null); + dto.put("viewCount", post.getViewCount()); + + dto.put("techBlogId", post.getTechBlog().getId()); + dto.put("techBlogCompanyName", post.getTechBlog().getCompanyName()); + dto.put("techBlogUrl", post.getTechBlog().getBlogUrl()); + dto.put("techBlogRssUrl", post.getTechBlog().getRssUrl()); + + return dto; + } +} \ No newline at end of file diff --git a/src/test/java/com/techfork/domain/recommendation/setup/UserDataSetupAndExporter.java b/src/test/java/com/techfork/domain/recommendation/setup/UserDataSetupAndExporter.java new file mode 100644 index 0000000..6d5c301 --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/setup/UserDataSetupAndExporter.java @@ -0,0 +1,314 @@ +package com.techfork.domain.recommendation.setup; + +import com.techfork.domain.activity.entity.ReadPost; +import com.techfork.domain.activity.repository.ReadPostRepository; +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.recommendation.setup.components.FileExporter; +import com.techfork.domain.recommendation.setup.components.TestDataGenerator; +import com.techfork.domain.recommendation.setup.components.TestDataGenerator.UserCreationResult; +import com.techfork.domain.recommendation.util.EvaluationFixtureLoader; +import com.techfork.domain.user.document.UserProfileDocument; +import com.techfork.domain.user.entity.User; +import com.techfork.domain.user.enums.EInterestCategory; +import com.techfork.domain.user.repository.UserProfileDocumentRepository; +import com.techfork.domain.user.repository.UserRepository; +import com.techfork.global.common.IntegrationTestBase; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.*; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.annotation.Commit; +import org.springframework.transaction.annotation.Transactional; + +import java.io.IOException; +import java.util.*; +import java.util.stream.Collectors; + +/** + * 
Testcontainer에서 사용자 데이터 생성 및 JSON으로 export + *

+ * 실행 방법: + * ./gradlew test --tests UserDataSetupAndExporter + *

+ * 주의: 게시글 데이터가 먼저 로드되어 있어야 합니다 (PostDataExporter 먼저 실행 필요) + */ +@Tag("evaluation-setup") +@Disabled("수동 실행용 - CI 제외") +@Slf4j +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +public class UserDataSetupAndExporter extends IntegrationTestBase { + + @Autowired + private UserRepository userRepository; + + @Autowired + private ReadPostRepository readPostRepository; + + @Autowired + private UserProfileDocumentRepository userProfileDocumentRepository; + + @Autowired + private TestDataGenerator testDataGenerator; + + @Autowired + private EvaluationFixtureLoader fixtureLoader; + + @Autowired + private FileExporter fileExporter; + + private static final int USER_COUNT = 5; + private static final int READ_POST_COUNT = 80; // 프로필 구성용 (읽은 글) - 1100개 데이터셋 기준 (약 7%) + private static final int HOLDOUT_COUNT = 30; // Ground Truth (평가용, 숨김) - 평가 샘플 (약 2.7%) + + // 실제 DB ID (Local) -> 원격 DB ID (Remote) 매핑용 + private static Map actualToRemotePostIdMap = new HashMap<>(); + + @Test + @Order(1) + @DisplayName("STEP 1: 게시글 픽스처 로드 (posts.json, post-documents.json)") + void step1_LoadPostFixtures() { + log.info("===== STEP 1: 게시글 픽스처 로드 ====="); + log.info("주의: PostDataExporter를 먼저 실행하여 게시글 데이터를 export해야 합니다."); + + try { + Map remoteToActualMap = fixtureLoader.loadPostsOnly(); + + // 역매핑 맵 생성 (실제 DB ID -> 원래 원격 ID) + actualToRemotePostIdMap.clear(); + remoteToActualMap.forEach((remoteId, post) -> + actualToRemotePostIdMap.put(post.getId(), remoteId)); + + log.info("✓ 게시글 픽스처 로드 및 ID 매핑 완료 ({} 개)", actualToRemotePostIdMap.size()); + } catch (Exception e) { + log.error("게시글 픽스처 로드 실패. 
PostDataExporter를 먼저 실행하세요.", e); + throw e; + } + } + + @Test + @Order(2) + @DisplayName("STEP 2: 테스트 사용자 5명 생성 (임베딩 포함)") + @Transactional + @Commit + void step2_CreateTestUsers() throws IOException { + log.info("===== STEP 2: 테스트 사용자 생성 ====="); + + List> interestCombos = Arrays.asList( + Arrays.asList(EInterestCategory.BACKEND), + Arrays.asList(EInterestCategory.FRONTEND), + Arrays.asList(EInterestCategory.AI_ML), + Arrays.asList(EInterestCategory.BACKEND, EInterestCategory.DATABASE), + Arrays.asList(EInterestCategory.AI_ML, EInterestCategory.DATA_SCIENCE) + ); + + Map> userGroundTruthMap = new HashMap<>(); + + for (int i = 0; i < USER_COUNT; i++) { + List interests = interestCombos.get(i); + + log.info("사용자 {}/{} 생성 중... (관심사: {})", i + 1, USER_COUNT, interests); + + UserCreationResult result = testDataGenerator.createTestUserWithGroundTruth(interests, READ_POST_COUNT, HOLDOUT_COUNT); + User user = result.user(); + + userGroundTruthMap.put(user.getId(), result.groundTruthScores()); + + log.info("✓ 사용자 생성 완료: ID={}, 관심사={}, 읽은 글={} 개, Ground Truth={} 개", + user.getId(), interests, READ_POST_COUNT, result.groundTruthScores().size()); + } + + log.info("===== STEP 2 완료: {} 명 사용자 생성 완료 =====\n", USER_COUNT); + + List users = userRepository.findAll(); + log.info("총 생성된 사용자: {} 명", users.size()); + + long profileCount = users.stream() + .filter(u -> userProfileDocumentRepository.findByUserId(u.getId()).isPresent()) + .count(); + log.info("UserProfile(임베딩) 생성된 사용자: {} 명", profileCount); + + // ID 변환: 실제 DB ID -> 원격 DB ID + Map> convertedGroundTruthMap = new HashMap<>(); + userGroundTruthMap.forEach((userId, scores) -> { + Map convertedScores = new HashMap<>(); + scores.forEach((actualPostId, score) -> { + Long remoteId = actualToRemotePostIdMap.get(actualPostId); + if (remoteId != null) { + convertedScores.put(remoteId, score); + } else { + log.warn("Ground Truth Post ID 매핑 실패: actualPostId={}", actualPostId); + } + }); + convertedGroundTruthMap.put(userId, 
convertedScores); + }); + + fileExporter.writeJsonFile("ground-truth.json", convertedGroundTruthMap); + log.info("✓ Ground Truth {} 명 export 완료 (원격 ID 변환 적용)", convertedGroundTruthMap.size()); + } + + @Test + @Order(3) + @DisplayName("STEP 3: 사용자 데이터를 JSON으로 Export") + @Transactional(readOnly = true) + void step3_ExportUserData() throws IOException { + log.info("===== STEP 3: 사용자 데이터 Export 시작 ====="); + + fileExporter.ensureOutputDirectory(); + + List users = exportUsers(); + log.info("✓ 사용자 {} 명 export 완료", users.size()); + + List readPosts = exportReadPosts(users); + log.info("✓ 읽은 글 이력 {} 개 export 완료", readPosts.size()); + + List userProfiles = exportUserProfiles(users); + log.info("✓ UserProfileDocument {} 개 export 완료 (임베딩 포함)", userProfiles.size()); + + log.info("===== STEP 3 완료 ====="); + log.info("출력 위치: {}", fileExporter.getOutputDir()); + log.info("\n생성된 파일:"); + log.info(" - users.json ({} 명)", users.size()); + log.info(" - read-posts.json ({} 개)", readPosts.size()); + log.info(" - user-profiles.json ({} 개, profileVector 3072차원 포함)", userProfiles.size()); + log.info("\nSTEP 2에서 이미 생성된 파일:"); + log.info(" - ground-truth.json (사용자별 정답 게시글 ID + 관련도 점수)"); + } + + private List exportUsers() throws IOException { + List users = userRepository.findAll(); + + // DTO 변환 (순환 참조 방지) + List> userDtos = users.stream() + .map(this::convertUserToDto) + .collect(Collectors.toList()); + + fileExporter.writeJsonFile("users.json", userDtos); + return users; + } + + private List exportReadPosts(List users) throws IOException { + List allReadPosts = new ArrayList<>(); + + for (User user : users) { + List userReadPosts = readPostRepository + .findRecentReadPostsByUserIdWithMinDuration( + user.getId(), + org.springframework.data.domain.Pageable.unpaged() + ); + allReadPosts.addAll(userReadPosts); + } + + // DTO 변환 + List> readPostDtos = allReadPosts.stream() + .map(this::convertReadPostToDto) + .collect(Collectors.toList()); + + fileExporter.writeJsonFile("read-posts.json", 
readPostDtos); + return allReadPosts; + } + + private List exportUserProfiles(List users) throws IOException { + List profiles = new ArrayList<>(); + int notFoundCount = 0; + + for (User user : users) { + Optional profileOpt = + userProfileDocumentRepository.findByUserId(user.getId()); + + if (profileOpt.isPresent()) { + profiles.add(profileOpt.get()); + } else { + notFoundCount++; + log.warn("UserProfileDocument not found for userId: {}", user.getId()); + } + } + + if (notFoundCount > 0) { + log.warn("총 {} 명의 UserProfileDocument를 찾지 못했습니다.", notFoundCount); + } + + // DTO 변환 (profileVector는 float[]이므로 List로 변환) + List> profileDtos = profiles.stream() + .map(this::convertUserProfileToDto) + .collect(Collectors.toList()); + + fileExporter.writeJsonFile("user-profiles.json", profileDtos); + + // 임베딩 차원 검증 + if (!profiles.isEmpty()) { + UserProfileDocument sample = profiles.get(0); + log.info("임베딩 차원 검증:"); + log.info(" - profileVector: {} 차원", + sample.getProfileVector() != null ? sample.getProfileVector().length : "null"); + } + + return profiles; + } + + private Map convertUserToDto(User user) { + Map dto = new HashMap<>(); + dto.put("id", user.getId()); + dto.put("email", user.getEmail()); + dto.put("nickname", user.getNickName()); + dto.put("profileImageUrl", user.getProfileImage()); + dto.put("socialType", user.getSocialType().name()); + dto.put("socialId", user.getSocialId()); + dto.put("role", user.getRole().name()); + + // 관심사 카테고리 + if (user.getInterestCategories() != null) { + List> interests = user.getInterestCategories().stream() + .map(ic -> { + Map interestDto = new HashMap<>(); + interestDto.put("category", ic.getCategory().name()); + + // 키워드 + if (ic.getKeywords() != null) { + List keywords = ic.getKeywords().stream() + .map(k -> k.getKeyword().name()) + .collect(Collectors.toList()); + interestDto.put("keywords", keywords); + } + + return interestDto; + }) + .collect(Collectors.toList()); + dto.put("interests", interests); + } + + return dto; + } + 
+ private Map convertReadPostToDto(ReadPost readPost) { + Map dto = new HashMap<>(); + dto.put("userId", readPost.getUser().getId()); + + // 실제 DB ID -> 원격 DB ID로 변환 + Long actualPostId = readPost.getPost().getId(); + Long remotePostId = actualToRemotePostIdMap.get(actualPostId); + dto.put("postId", remotePostId != null ? remotePostId : actualPostId); + + dto.put("readAt", readPost.getReadAt().toString()); + dto.put("readDurationSeconds", readPost.getReadDurationSeconds()); + return dto; + } + + private Map convertUserProfileToDto(UserProfileDocument profile) { + Map dto = new HashMap<>(); + dto.put("id", profile.getId()); + dto.put("userId", profile.getUserId()); + dto.put("profileText", profile.getProfileText()); + + // float[] -> List 변환 + if (profile.getProfileVector() != null) { + List vectorList = new ArrayList<>(); + for (float v : profile.getProfileVector()) { + vectorList.add(v); + } + dto.put("profileVector", vectorList); + } + + dto.put("interests", profile.getInterests()); + + return dto; + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/setup/components/FileExporter.java b/src/test/java/com/techfork/domain/recommendation/setup/components/FileExporter.java new file mode 100644 index 0000000..1b2895b --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/setup/components/FileExporter.java @@ -0,0 +1,47 @@ +package com.techfork.domain.recommendation.setup.components; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * JSON 파일 export를 위한 공통 유틸리티 클래스 + */ +@Slf4j +@Component +public class FileExporter { + + private static final String OUTPUT_DIR = 
"src/test/resources/fixtures/evaluation"; + + private final ObjectMapper objectMapper; + + public FileExporter() { + this.objectMapper = new ObjectMapper() + .registerModule(new JavaTimeModule()) + .enable(SerializationFeature.INDENT_OUTPUT); + } + + public void ensureOutputDirectory() throws IOException { + Path outputPath = Paths.get(OUTPUT_DIR); + Files.createDirectories(outputPath); + log.info("출력 디렉토리: {}", outputPath.toAbsolutePath()); + } + + public String getOutputDir() { + return OUTPUT_DIR; + } + + public void writeJsonFile(String filename, Object data) throws IOException { + File outputFile = new File(OUTPUT_DIR, filename); + objectMapper.writeValue(outputFile, data); + log.debug("파일 작성: {}", outputFile.getAbsolutePath()); + } +} \ No newline at end of file diff --git a/src/test/java/com/techfork/domain/recommendation/setup/components/GroundTruthGenerator.java b/src/test/java/com/techfork/domain/recommendation/setup/components/GroundTruthGenerator.java new file mode 100644 index 0000000..79e07d3 --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/setup/components/GroundTruthGenerator.java @@ -0,0 +1,144 @@ +package com.techfork.domain.recommendation.setup.components; + +import com.techfork.domain.post.document.PostDocument; +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.post.repository.PostDocumentRepository; +import com.techfork.domain.user.document.UserProfileDocument; +import com.techfork.global.llm.LlmClient; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Ground Truth 계산 및 품질 검증 + * LLM-as-a-Judge 방식을 사용하여 정답 데이터를 생성합니다. 
+ */ +@Slf4j +@Component +@RequiredArgsConstructor +public class GroundTruthGenerator { + + private final PostDocumentRepository postDocumentRepository; + private final LlmClient llmClient; + + /** + * Ground Truth 관련도 점수 계산 (LLM 기반) + * + * @param posts 평가할 게시글 목록 + * @param userProfile 사용자 프로필 문서 (프로필 텍스트 포함) + * @return 게시글 ID -> 관련도 점수 (1~5점) + */ + public Map calculateGroundTruth( + List posts, + UserProfileDocument userProfile) { + + Map groundTruthScores = new HashMap<>(); + int count = 0; + + for (Post post : posts) { + count++; + log.info("Ground Truth 평가 중 ({}/{}): Post ID {}", count, posts.size(), post.getId()); + + try { + int score = calculateRelevanceScoreWithLLM(post, userProfile); + groundTruthScores.put(post.getId(), score); + } catch (Exception e) { + log.error("LLM 평가 실패 (Post ID {}): {}", post.getId(), e.getMessage()); + // 실패 시 기본 점수 1점 부여 (안전장치) + groundTruthScores.put(post.getId(), 1); + } + } + + return groundTruthScores; + } + + /** + * LLM을 사용하여 게시글의 관련도 점수 평가 (1~5점) + */ + private int calculateRelevanceScoreWithLLM(Post post, UserProfileDocument userProfile) { + // PostDocument에서 더 풍부한 정보 가져오기 (요약문, 본문 청크 등) + Optional postDocOpt = postDocumentRepository.findByPostId(post.getId()); + + String postSummary = postDocOpt.map(PostDocument::getSummary).orElse(post.getSummary()); + + // 본문 내용 일부 가져오기 (Content Chunks가 있다면 도입부와 결론부 위주로 추출) + StringBuilder contentContext = new StringBuilder(); + if (postDocOpt.isPresent() && postDocOpt.get().getContentChunks() != null) { + var chunks = postDocOpt.get().getContentChunks(); + if (!chunks.isEmpty()) { + // 1. 도입부 (첫 번째 청크) + String intro = chunks.get(0).getChunkText(); + contentContext.append("[도입부]\n") + .append(intro.substring(0, Math.min(intro.length(), 1500))) + .append("\n\n"); + + // 2. 
결론부 (마지막 청크 - 도입부와 다른 경우에만) + if (chunks.size() > 1) { + String conclusion = chunks.get(chunks.size() - 1).getChunkText(); + contentContext.append("[결론 및 요약]\n") + .append(conclusion.substring(0, Math.min(conclusion.length(), 1500))); + } + } + } + + String systemPrompt = "당신은 기술 블로그 추천 시스템의 품질 평가자(Judge)입니다. 사용자 프로필과 게시글 내용을 바탕으로 적합성을 1~5점 척도로 평가하세요."; + + String userPrompt = String.format(""" + 다음 사용자가 해당 게시글을 추천받았을 때 얼마나 만족할지 평가해주세요. + + ## 사용자 프로필 + %s + + ## 게시글 정보 + - 제목: %s + - 회사/블로그: %s + - 요약: %s + - 본문 내용(일부): + %s + + ## 평가 기준 + 5점 (매우 강한 추천): 사용자의 핵심 관심사(주력 기술, 해결하려는 문제)와 정확히 일치하며, 반드시 읽어야 할 글. + 4점 (추천): 사용자의 관심사와 밀접하게 관련되어 있으며, 흥미를 느낄 만한 글. + 3점 (보통): 사용자의 관심사와 관련은 있으나, 핵심 분야가 아니거나 너무 일반적인 내용. + 2점 (약간 관련): 키워드는 일부 겹치지만, 사용자의 주된 관심사와 거리가 먼 글. + 1점 (관련 없음): 사용자의 관심사와 전혀 무관한 글. + + ## 응답 형식 + 반드시 점수(숫자 1~5)만 출력하세요. 설명은 필요 없습니다. + """, + userProfile.getProfileText(), + post.getTitle(), + post.getCompany(), + postSummary, + contentContext.length() > 0 ? contentContext.toString() : "(본문 데이터 없음)" + ); + + String response = llmClient.call(systemPrompt, userPrompt); + return parseScore(response); + } + + private int parseScore(String response) { + // 응답에서 숫자만 추출 + try { + // "점수: 5" 같은 형식에 대비 + Matcher matcher = Pattern.compile("(\\d)").matcher(response); + if (matcher.find()) { + int score = Integer.parseInt(matcher.group(1)); + return Math.max(1, Math.min(5, score)); // 1~5 범위 제한 + } + + // 숫자를 못 찾은 경우 + log.warn("LLM 응답에서 점수를 파싱할 수 없음: {}", response); + return 1; + } catch (NumberFormatException e) { + return 1; + } + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/setup/components/GroundTruthValidator.java b/src/test/java/com/techfork/domain/recommendation/setup/components/GroundTruthValidator.java new file mode 100644 index 0000000..414b4a4 --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/setup/components/GroundTruthValidator.java @@ -0,0 +1,65 @@ +package com.techfork.domain.recommendation.setup.components; + 
+import com.techfork.domain.user.enums.EInterestCategory; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.Map; + +@Slf4j +@Component +public class GroundTruthValidator { + + /** + * Ground Truth 품질 검증 + * - 최소 3점 이상 글이 충분한지 + * - 점수 분포가 편향되지 않았는지 + */ + public void validateGroundTruthQuality( + Map groundTruthScores, + List interests) { + + if (groundTruthScores.isEmpty()) { + log.error("Ground Truth가 비어있습니다!"); + throw new IllegalStateException("Ground Truth가 비어있습니다."); + } + + int totalCount = groundTruthScores.size(); + long highQualityCount = groundTruthScores.values().stream() + .filter(score -> score >= 3) + .count(); + + double highQualityRatio = (double) highQualityCount / totalCount; + + log.info("===== Ground Truth 품질 검증 ====="); + log.info("총 개수: {}", totalCount); + log.info("3점 이상: {} 개 ({}%)", highQualityCount, String.format("%.1f", highQualityRatio * 100)); + + // 경고: 3점 이상이 50% 미만 + if (highQualityRatio < 0.5) { + log.warn("⚠️ 경고: 3점 이상 비율이 낮습니다 ({}%). 관심사와 맞는 글이 부족할 수 있습니다.", + String.format("%.1f", highQualityRatio * 100)); + log.warn("관심사: {}", interests); + } + + // 에러: 3점 이상이 20% 미만 + if (highQualityRatio < 0.2) { + log.error("❌ Ground Truth 품질이 너무 낮습니다. 관심사와 맞는 게시글이 부족합니다."); + throw new IllegalStateException( + String.format("Ground Truth 품질 불량: 3점 이상 비율 %.1f%% (최소 20%% 필요)", + highQualityRatio * 100)); + } + + // 최고 점수 확인 + int maxScore = groundTruthScores.values().stream() + .max(Integer::compareTo) + .orElse(0); + + if (maxScore < 3) { + log.warn("⚠️ 경고: 최고 점수가 {}점입니다. 
관련도 높은 글이 없을 수 있습니다.", maxScore); + } + + log.info("✓ Ground Truth 품질 검증 통과"); + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/setup/components/PostMatcher.java b/src/test/java/com/techfork/domain/recommendation/setup/components/PostMatcher.java new file mode 100644 index 0000000..dfbfbea --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/setup/components/PostMatcher.java @@ -0,0 +1,121 @@ +package com.techfork.domain.recommendation.setup.components; + +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.post.repository.PostRepository; +import com.techfork.domain.user.enums.EInterestCategory; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import org.springframework.transaction.annotation.Transactional; + +import java.util.*; +import java.util.stream.Collectors; + +/** + * 관심사 카테고리 기반 게시글 매칭 + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class PostMatcher { + + private final PostRepository postRepository; + + // 관심사 카테고리별 키워드 매핑 + private static final Map> INTEREST_KEYWORDS = createKeywordMap(); + + private static Map> createKeywordMap() { + Map> map = new HashMap<>(); + + map.put(EInterestCategory.BACKEND, Arrays.asList( + "Spring", "Java", "Kotlin", "API", "서버", "Backend", "백엔드", "JPA", "Hibernate" + )); + + map.put(EInterestCategory.AI_ML, Arrays.asList( + "AI", "ML", "머신러닝", "딥러닝", "LLM", "GPT", "인공지능", "모델", "학습" + )); + + map.put(EInterestCategory.FRONTEND, Arrays.asList( + "React", "Vue", "JavaScript", "CSS", "UI", "Frontend", "프론트엔드", + "TypeScript", "HTML", "웹", "브라우저", "Node.js", "Next", "Angular", + "Webpack", "번들", "SPA", "SSR", "렌더링", "컴포넌트", "디자인" + )); + + map.put(EInterestCategory.DATA_ENGINEERING, Arrays.asList( + "데이터", "분석", "Spark", "Kafka", "파이프라인", "ETL" + )); + + map.put(EInterestCategory.DATA_SCIENCE, Arrays.asList( + "데이터", "분석", "ML", "통계", "시각화", "예측" + )); + + 
map.put(EInterestCategory.DATABASE, Arrays.asList( + "SQL", "Database", "MySQL", "PostgreSQL", "DB", "쿼리", "인덱스" + )); + + map.put(EInterestCategory.DEVOPS, Arrays.asList( + "DevOps", "Docker", "Kubernetes", "CI/CD", "배포", "인프라" + )); + + map.put(EInterestCategory.CLOUD, Arrays.asList( + "AWS", "클라우드", "Cloud", "Azure", "GCP" + )); + + return map; + } + + /** + * 관심사와 관련된 게시글 찾기 (제목 기반 키워드 매칭) + */ + @Transactional(readOnly = true) + public List findPostsRelatedToInterests(List interests, int limit) { + List allPosts = postRepository.findAll(); + + // 관심사 키워드와 매칭되는 게시글 찾기 + List relatedPosts = allPosts.stream() + .filter(post -> matchesInterests(post, interests)) + .limit(limit) + .collect(Collectors.toList()); + + // 매칭 안 되면 랜덤으로 채우기 + if (relatedPosts.size() < limit && !allPosts.isEmpty()) { + return fillWithRandomPosts(relatedPosts, allPosts, limit); + } + + return relatedPosts; + } + + /** + * 게시글이 관심사와 매칭되는지 확인 + */ + private boolean matchesInterests(Post post, List interests) { + String title = post.getTitle().toLowerCase(); + + return interests.stream() + .flatMap(interest -> INTEREST_KEYWORDS.getOrDefault(interest, Collections.emptyList()).stream()) + .anyMatch(keyword -> title.contains(keyword.toLowerCase())); + } + + /** + * 부족한 만큼 랜덤 게시글로 채우기 + */ + private List fillWithRandomPosts(List relatedPosts, List allPosts, int limit) { + List remaining = new ArrayList<>(allPosts); + remaining.removeAll(relatedPosts); + Collections.shuffle(remaining); + + List combined = new ArrayList<>(relatedPosts); + int needed = limit - relatedPosts.size(); + int available = Math.min(needed, remaining.size()); + combined.addAll(remaining.subList(0, available)); + + return combined; + } + + /** + * 특정 관심사 카테고리의 키워드 목록 반환 + */ + public List getKeywordsForCategory(EInterestCategory category) { + return INTEREST_KEYWORDS.getOrDefault(category, Collections.emptyList()); + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/setup/components/TestDataGenerator.java 
b/src/test/java/com/techfork/domain/recommendation/setup/components/TestDataGenerator.java new file mode 100644 index 0000000..a13982f --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/setup/components/TestDataGenerator.java @@ -0,0 +1,236 @@ +package com.techfork.domain.recommendation.setup.components; + +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.post.repository.PostRepository; +import com.techfork.domain.user.document.UserProfileDocument; +import com.techfork.domain.user.entity.User; +import com.techfork.domain.user.enums.EInterestCategory; +import com.techfork.domain.user.repository.UserProfileDocumentRepository; +import com.techfork.domain.user.service.UserProfileService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Tag; +import org.springframework.stereotype.Component; +import org.springframework.transaction.annotation.Transactional; + +import java.util.*; + +/** + * 추천 시스템 테스트를 위한 데이터 생성기 + * DB의 실제 게시글 데이터를 기반으로 테스트용 사용자 프로필과 Ground Truth 생성 + */ +@Tag("evaluation-setup") +@Disabled("데이터 셋업용 - CI 제외") +@Slf4j +@Component +@RequiredArgsConstructor +public class TestDataGenerator { + + public record UserCreationResult( + User user, + Map groundTruthScores + ) {} + + private final PostRepository postRepository; + private final UserProfileService userProfileService; + private final UserProfileDocumentRepository userProfileDocumentRepository; + private final org.springframework.data.elasticsearch.core.ElasticsearchOperations elasticsearchOperations; + + // 분리된 컴포넌트 + private final PostMatcher postMatcher; + private final UserTestDataBuilder userTestDataBuilder; + private final GroundTruthGenerator groundTruthGenerator; + private final GroundTruthValidator groundTruthValidator; + + // 공통 읽은 글 풀 (모든 사용자가 공통으로 읽는 글) + private static List sharedReadPosts = null; + + /** + * 테스트용 사용자 생성 - Leave-K-Out 방식 + * + * @param 
interestCategories 관심사 카테고리 목록 + * @param readPostCount 읽은 글 개수 (프로필 구성용) + * @param holdoutCount 숨길 정답 개수 (Ground Truth) + * @return 사용자 생성 결과 (User + Ground Truth 게시글 ID) + */ + public UserCreationResult createTestUserWithGroundTruth(List interestCategories, int readPostCount, int holdoutCount) { + return createTestUserWithGroundTruth(interestCategories, readPostCount, holdoutCount, 30); // 기본 30% 공통 + } + + /** + * 테스트용 사용자 생성 - Leave-K-Out 방식 (공통 읽은 글 비율 지정 가능) + * + * @param interestCategories 관심사 카테고리 목록 + * @param readPostCount 읽은 글 개수 (프로필 구성용) + * @param holdoutCount 숨길 정답 개수 (Ground Truth) + * @param sharedPostPercentage 공통 읽은 글 비율 (0~100, 기본 30) + * @return 사용자 생성 결과 (User + Ground Truth 게시글 ID) + */ + public UserCreationResult createTestUserWithGroundTruth( + List interestCategories, + int readPostCount, + int holdoutCount, + int sharedPostPercentage) { + + // 1. 사용자 생성 (관심사 포함) + User user = userTestDataBuilder.createUserWithInterests(interestCategories); + + // 공통 읽은 글 풀 초기화 (첫 사용자 생성 시) + int sharedPostCount = (int) (readPostCount * sharedPostPercentage / 100.0); + int personalPostCount = readPostCount - sharedPostCount; + + if (sharedReadPosts == null) { + // 전체 게시글 중 랜덤하게 공통 읽은 글 선택 + List allPosts = new ArrayList<>(postRepository.findAll()); + Collections.shuffle(allPosts); + sharedReadPosts = allPosts.subList(0, Math.min(sharedPostCount, allPosts.size())); + log.info("공통 읽은 글 풀 초기화: {} 개 (전체 {}개 중 {}%)", + sharedReadPosts.size(), readPostCount, sharedPostPercentage); + } + + // 2. Leave-K-Out: 관심 있는 글 중 일부를 숨겨서 Ground Truth로 사용 + int totalRelatedPosts = personalPostCount + holdoutCount; + List relatedPosts = postMatcher.findPostsRelatedToInterests(interestCategories, totalRelatedPosts); + Collections.shuffle(relatedPosts); + + if (relatedPosts.size() < totalRelatedPosts) { + log.warn("관심 있는 글이 부족합니다. 
요청: {}, 실제: {}", totalRelatedPosts, relatedPosts.size()); + } + + // 읽은 글 = 공통 글 + 개인화 글 + List readPosts = new ArrayList<>(); + readPosts.addAll(sharedReadPosts); + + int actualPersonalCount = Math.min(relatedPosts.size(), personalPostCount); + readPosts.addAll(relatedPosts.subList(0, actualPersonalCount)); + + // 3. 읽은 글 저장 (프로필 구성용) + userTestDataBuilder.createReadPosts(user, readPosts); + log.info("읽은 글 구성: 공통 {} 개 + 개인화 {} 개 = 총 {} 개", + sharedReadPosts.size(), actualPersonalCount, readPosts.size()); + + // 4. 스크랩한 글 생성 (읽은 글 중 25% 정도를 스크랩) + int scrapCount = Math.max(5, readPosts.size() / 4); // 최소 5개 + userTestDataBuilder.createScrapPosts(user, readPosts, scrapCount); + log.info("스크랩한 글: {} 개 생성", scrapCount); + + // 5. 검색 기록 생성 (관심사 키워드 기반) + List searchKeywords = generateSearchKeywords(interestCategories); + int searchHistoryCount = Math.min(30, searchKeywords.size() * 2); // 최대 30개 + userTestDataBuilder.createSearchHistories(user, searchKeywords, searchHistoryCount); + log.info("검색 기록: {} 개 생성", searchHistoryCount); + + // UserProfile 생성 (임베딩 포함) - 동기 버전 사용 + // Ground Truth 점수 계산 전에 프로필 벡터가 필요함 + UserProfileDocument userProfile = null; + try { + userProfileService.generateUserProfileSync(user.getId()); + + // Elasticsearch Refresh: 저장이 검색 가능해지도록 강제 갱신 + elasticsearchOperations.indexOps(UserProfileDocument.class).refresh(); + + Optional userProfileOpt = userProfileDocumentRepository.findByUserId(user.getId()); + if (userProfileOpt.isPresent()) { + userProfile = userProfileOpt.get(); + log.info("사용자 프로필 및 임베딩 생성 완료: userId={}", user.getId()); + } else { + log.error("사용자 프로필을 찾을 수 없습니다 (Refresh 후에도 없음): userId={}", user.getId()); + } + } catch (Exception e) { + log.error("사용자 프로필 생성 실패: userId={}", user.getId(), e); + throw e; + } + + if (userProfile == null) { + throw new RuntimeException("UserProfile 생성 실패. Ground Truth를 계산할 수 없습니다."); + } + + // 6. 
Ground Truth 관련도 점수 계산 (LLM 기반) + int actualHoldoutCount = Math.min(relatedPosts.size() - actualPersonalCount, holdoutCount); + List holdoutPosts = relatedPosts.subList(actualPersonalCount, actualPersonalCount + actualHoldoutCount); + + Map groundTruthScores = groundTruthGenerator.calculateGroundTruth( + holdoutPosts, + userProfile + ); + + log.info("Ground Truth 설정: {} 개 (평가용, 숨김)", actualHoldoutCount); + + Map scoreDistribution = groundTruthScores.values().stream() + .collect(java.util.stream.Collectors.groupingBy(score -> score, java.util.stream.Collectors.counting())); + log.info("Ground Truth 점수 분포: {}", scoreDistribution); + + groundTruthValidator.validateGroundTruthQuality(groundTruthScores, interestCategories); + + return new UserCreationResult(user, groundTruthScores); + } + + private List generateSearchKeywords(List interests) { + Map> keywordMap = new HashMap<>(); + + // BACKEND 관련 검색어 + keywordMap.put(EInterestCategory.BACKEND, Arrays.asList( + "Spring Boot", "Java", "Kotlin", "API 설계", "서버 아키텍처", + "마이크로서비스", "REST API", "데이터베이스 최적화", "JPA", "Hibernate" + )); + + // FRONTEND 관련 검색어 + keywordMap.put(EInterestCategory.FRONTEND, Arrays.asList( + "React", "Vue.js", "JavaScript", "TypeScript", "CSS", + "UI/UX", "웹 성능 최적화", "반응형 디자인", "Next.js", "Webpack" + )); + + // AI/ML 관련 검색어 + keywordMap.put(EInterestCategory.AI_ML, Arrays.asList( + "딥러닝", "머신러닝", "LLM", "ChatGPT", "AI 모델", + "TensorFlow", "PyTorch", "자연어처리", "컴퓨터 비전", "데이터 분석" + )); + + // DATABASE 관련 검색어 + keywordMap.put(EInterestCategory.DATABASE, Arrays.asList( + "MySQL", "PostgreSQL", "Redis", "MongoDB", "쿼리 최적화", + "인덱싱", "데이터베이스 설계", "SQL", "NoSQL", "데이터 모델링" + )); + + // DATA_SCIENCE 관련 검색어 + keywordMap.put(EInterestCategory.DATA_SCIENCE, Arrays.asList( + "데이터 분석", "통계", "Python", "데이터 시각화", "Pandas", + "빅데이터", "데이터 파이프라인", "ETL", "데이터 엔지니어링", "Jupyter" + )); + + // DEVOPS 관련 검색어 + keywordMap.put(EInterestCategory.DEVOPS, Arrays.asList( + "Docker", "Kubernetes", "CI/CD", "AWS", "클라우드", + "인프라", 
"모니터링", "로깅", "배포 자동화", "Terraform" + )); + + // MOBILE 관련 검색어 + keywordMap.put(EInterestCategory.ANDROID, Arrays.asList( + "Android", "React Native", "Flutter", "모바일 앱", + "Kotlin", "크로스플랫폼", "앱 성능", "모바일 UI" + )); + + // SECURITY 관련 검색어 + keywordMap.put(EInterestCategory.SECURITY, Arrays.asList( + "보안", "인증", "암호화", "OAuth", "JWT", + "해킹 방어", "보안 취약점", "HTTPS", "방화벽", "보안 아키텍처" + )); + + // 사용자의 관심사에 해당하는 키워드들을 모두 모음 + List allKeywords = new ArrayList<>(); + for (EInterestCategory interest : interests) { + List keywords = keywordMap.getOrDefault(interest, Collections.emptyList()); + allKeywords.addAll(keywords); + } + + // 키워드가 없으면 기본 키워드 사용 + if (allKeywords.isEmpty()) { + allKeywords.addAll(Arrays.asList( + "개발", "프로그래밍", "코딩", "소프트웨어", "기술 블로그" + )); + } + + return allKeywords; + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/setup/components/UserTestDataBuilder.java b/src/test/java/com/techfork/domain/recommendation/setup/components/UserTestDataBuilder.java new file mode 100644 index 0000000..e014771 --- /dev/null +++ b/src/test/java/com/techfork/domain/recommendation/setup/components/UserTestDataBuilder.java @@ -0,0 +1,138 @@ +package com.techfork.domain.recommendation.setup.components; + +import com.techfork.domain.activity.entity.ReadPost; +import com.techfork.domain.activity.entity.ScrabPost; +import com.techfork.domain.activity.entity.SearchHistory; +import com.techfork.domain.activity.repository.ReadPostRepository; +import com.techfork.domain.activity.repository.ScrabPostRepository; +import com.techfork.domain.activity.repository.SearchHistoryRepository; +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.user.entity.User; +import com.techfork.domain.user.entity.UserInterestCategory; +import com.techfork.domain.user.entity.UserInterestKeyword; +import com.techfork.domain.user.enums.EInterestCategory; +import com.techfork.domain.user.enums.EInterestKeyword; +import com.techfork.domain.user.enums.SocialType; 
+import com.techfork.domain.user.repository.UserInterestCategoryRepository; +import com.techfork.domain.user.repository.UserRepository; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.UUID; + +@Slf4j +@Component +@RequiredArgsConstructor +public class UserTestDataBuilder { + + private final UserRepository userRepository; + private final UserInterestCategoryRepository userInterestCategoryRepository; + private final ReadPostRepository readPostRepository; + private final ScrabPostRepository scrabPostRepository; + private final SearchHistoryRepository searchHistoryRepository; + + public User createUserWithInterests(List interestCategories) { + User user = User.createSocialUser( + SocialType.KAKAO, + "testSocialId_" + UUID.randomUUID(), + "test_" + System.currentTimeMillis() + "@example.com", + null + ); + user = userRepository.save(user); + + log.info("테스트 사용자 생성: ID: {}", user.getId()); + + // 관심사 카테고리 및 키워드 추가 + for (EInterestCategory category : interestCategories) { + UserInterestCategory interestCategory = UserInterestCategory.create(user, category); + userInterestCategoryRepository.save(interestCategory); + + // 해당 카테고리의 키워드 중 랜덤하게 2~4개 선택 + List availableKeywords = new ArrayList<>( + EInterestKeyword.getKeywordsByCategory(category) + ); + Collections.shuffle(availableKeywords); + int keywordCount = 2 + (int) (Math.random() * 3); // 2~4개 + + for (int i = 0; i < Math.min(keywordCount, availableKeywords.size()); i++) { + UserInterestKeyword keyword = UserInterestKeyword.create( + interestCategory, + availableKeywords.get(i) + ); + interestCategory.addKeyword(keyword); + } + + userInterestCategoryRepository.save(interestCategory); + } + + log.info("관심사 추가: {} (각 카테고리별 키워드 포함)", interestCategories); + + return user; + } + + public void createReadPosts(User 
user, List posts) { + LocalDateTime now = LocalDateTime.now(); + List readPosts = new ArrayList<>(); + + for (int i = 0; i < posts.size(); i++) { + Post post = posts.get(i); + ReadPost readPost = ReadPost.create( + user, + post, + now.minusDays(posts.size() - i), + 180 // 3분 읽음 + ); + readPosts.add(readPost); + } + + readPostRepository.saveAll(readPosts); + log.debug("읽은 글 {} 개 생성 완료", posts.size()); + } + + + public void createScrapPosts(User user, List readPosts, int scrapCount) { + LocalDateTime now = LocalDateTime.now(); + List scrabPosts = new ArrayList<>(); + + List postsToScrap = new ArrayList<>(readPosts); + Collections.shuffle(postsToScrap); + + int actualScrapCount = Math.min(scrapCount, postsToScrap.size()); + + for (int i = 0; i < actualScrapCount; i++) { + Post post = postsToScrap.get(i); + ScrabPost scrabPost = ScrabPost.create( + user, + post, + now.minusDays(readPosts.size() - i - 5) // 읽은 시점보다 약간 후에 스크랩 + ); + scrabPosts.add(scrabPost); + } + + scrabPostRepository.saveAll(scrabPosts); + log.debug("스크랩한 글 {} 개 생성 완료", actualScrapCount); + } + + public void createSearchHistories(User user, List searchKeywords, int searchHistoryCount) { + LocalDateTime now = LocalDateTime.now(); + List searchHistories = new ArrayList<>(); + + for (int i = 0; i < searchHistoryCount; i++) { + String searchWord = searchKeywords.get(i % searchKeywords.size()); + SearchHistory searchHistory = SearchHistory.create( + user, + searchWord, + now.minusDays(searchHistoryCount - i * 2) // 읽기 활동 사이사이에 검색 + ); + searchHistories.add(searchHistory); + } + + searchHistoryRepository.saveAll(searchHistories); + log.debug("검색 기록 {} 개 생성 완료", searchHistoryCount); + } +} diff --git a/src/test/java/com/techfork/domain/recommendation/util/EvaluationFixtureLoader.java b/src/test/java/com/techfork/domain/recommendation/util/EvaluationFixtureLoader.java new file mode 100644 index 0000000..9fc365c --- /dev/null +++ 
b/src/test/java/com/techfork/domain/recommendation/util/EvaluationFixtureLoader.java @@ -0,0 +1,454 @@ +package com.techfork.domain.recommendation.util; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.techfork.domain.activity.entity.ReadPost; +import com.techfork.domain.activity.repository.ReadPostRepository; +import com.techfork.domain.post.document.ContentChunk; +import com.techfork.domain.post.document.PostDocument; +import com.techfork.domain.post.entity.Post; +import com.techfork.domain.post.repository.PostDocumentRepository; +import com.techfork.domain.post.repository.PostRepository; +import com.techfork.domain.source.entity.TechBlog; +import com.techfork.domain.source.repository.TechBlogRepository; +import com.techfork.domain.user.document.UserProfileDocument; +import com.techfork.domain.user.entity.User; +import com.techfork.domain.user.entity.UserInterestCategory; +import com.techfork.domain.user.entity.UserInterestKeyword; +import com.techfork.domain.user.enums.EInterestCategory; +import com.techfork.domain.user.enums.EInterestKeyword; +import com.techfork.domain.user.enums.Role; +import com.techfork.domain.user.enums.SocialType; +import com.techfork.domain.user.repository.UserProfileDocumentRepository; +import com.techfork.domain.user.repository.UserRepository; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.core.io.ClassPathResource; +import org.springframework.stereotype.Component; +import org.springframework.transaction.annotation.Transactional; + +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * JSON fixture를 읽어서 Testcontainer에 로드하는 클래스 + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class EvaluationFixtureLoader { + + 
private final UserRepository userRepository; + private final PostRepository postRepository; + private final ReadPostRepository readPostRepository; + private final PostDocumentRepository postDocumentRepository; + private final UserProfileDocumentRepository userProfileDocumentRepository; + private final TechBlogRepository techBlogRepository; + + private final ObjectMapper objectMapper = new ObjectMapper() + .registerModule(new JavaTimeModule()); + + private static final String FIXTURE_PATH = "fixtures/evaluation/"; + + @Transactional + public Map> loadAll() { + log.info("===== Fixture 로드 시작 ====="); + + try { + Map userMap = loadUsers(); + log.info("✓ 사용자 {} 명 로드 완료", userMap.size()); + + Map postMap = loadPosts(); + log.info("✓ 게시글 {} 개 로드 완료", postMap.size()); + + int readPostCount = loadReadPosts(userMap, postMap); + log.info("✓ 읽은 글 이력 {} 개 로드 완료", readPostCount); + + int postDocCount = loadPostDocuments(postMap); + log.info("✓ PostDocument {} 개 로드 완료 (임베딩 포함)", postDocCount); + + int userProfileCount = loadUserProfiles(userMap); + log.info("✓ UserProfileDocument {} 개 로드 완료 (임베딩 포함)", userProfileCount); + + Map> groundTruth = loadGroundTruth(userMap, postMap); + log.info("✓ Ground Truth {} 명 사용자 로드 완료", groundTruth.size()); + + log.info("===== Fixture 로드 완료 =====\n"); + + return groundTruth; + + } catch (IOException e) { + log.error("Fixture 로드 실패", e); + throw new RuntimeException("Fixture 로드 중 오류 발생", e); + } + } + + @Transactional + public Map loadPostsOnly() { + log.info("===== 게시글 Fixture 로드 시작 ====="); + + try { + Map postMap = loadPosts(); + log.info("✓ 게시글 {} 개 로드 완료", postMap.size()); + + int postDocCount = loadPostDocuments(postMap); + log.info("✓ PostDocument {} 개 로드 완료 (임베딩 포함)", postDocCount); + + log.info("===== 게시글 Fixture 로드 완료 =====\n"); + return postMap; + + } catch (IOException e) { + log.error("게시글 Fixture 로드 실패", e); + throw new RuntimeException("게시글 Fixture 로드 중 오류 발생", e); + } + } + + private Map loadUsers() throws IOException { + List> 
userDtos = readJsonFile("users.json", new TypeReference<>() { + }); + + Map userMap = new HashMap<>(); + + for (Map dto : userDtos) { + Long originalUserId = ((Number) dto.get("id")).longValue(); // JSON의 원래 ID + String email = (String) dto.get("email"); + String nickname = (String) dto.get("nickname"); + String profileImageUrl = (String) dto.get("profileImageUrl"); + String socialType = (String) dto.get("socialType"); + String socialId = (String) dto.get("socialId"); + String roleStr = (String) dto.get("role"); + + User user = User.builder() + .email(email) + .nickName(nickname) + .profileImage(profileImageUrl) + .socialType(SocialType.valueOf(socialType)) + .socialId(socialId) + .role(Role.valueOf(roleStr)) + .build(); + + user = userRepository.save(user); + + // 관심사 추가 + List> interests = + (List>) dto.get("interests"); + + if (interests != null) { + for (Map interestDto : interests) { + String categoryStr = (String) interestDto.get("category"); + EInterestCategory category = EInterestCategory.valueOf(categoryStr); + + UserInterestCategory interestCategory = + UserInterestCategory.create(user, category); + + // 키워드 추가 + List keywords = (List) interestDto.get("keywords"); + if (keywords != null) { + for (String keywordStr : keywords) { + EInterestKeyword keyword = EInterestKeyword.valueOf(keywordStr); + UserInterestKeyword interestKeyword = + UserInterestKeyword.create(interestCategory, keyword); + interestCategory.addKeyword(interestKeyword); + } + } + + user.getInterestCategories().add(interestCategory); + } + + user = userRepository.save(user); + } + + // JSON의 원래 ID를 키로 사용 (Ground-Truth, ReadPost 매핑을 위해) + userMap.put(originalUserId, user); + } + + return userMap; + } + + private Map loadPosts() throws IOException { + List> postDtos = readJsonFile("posts.json", new TypeReference<>() { + }); + + Map postMap = new HashMap<>(); + + // TechBlog ID -> TechBlog 매핑 + Map techBlogMap = new HashMap<>(); + + for (Map dto : postDtos) { + Long originalPostId = ((Number) 
dto.get("id")).longValue(); // JSON의 원래 ID + Long techBlogId = ((Number) dto.get("techBlogId")).longValue(); + + // TechBlog 생성 또는 조회 + TechBlog techBlog = techBlogMap.get(techBlogId); + if (techBlog == null) { + String companyName = (String) dto.get("techBlogCompanyName"); + String blogUrl = (String) dto.get("techBlogUrl"); + String rssUrl = (String) dto.get("techBlogRssUrl"); + + techBlog = TechBlog.create(companyName, blogUrl, rssUrl, null); + techBlog = techBlogRepository.save(techBlog); + techBlogMap.put(techBlogId, techBlog); + } + + // Post 생성 + String title = (String) dto.get("title"); + String url = (String) dto.get("url"); + String summary = (String) dto.get("summary"); + String shortSummary = (String) dto.get("shortSummary"); + String company = (String) dto.get("company"); + String logoUrl = (String) dto.get("logoUrl"); + String thumbnailUrl = (String) dto.get("thumbnailUrl"); + String publishedAtStr = (String) dto.get("publishedAt"); + LocalDateTime publishedAt = publishedAtStr != null ? 
+ LocalDateTime.parse(publishedAtStr) : null; + + Post post = Post.builder() + .title(title) + .url(url) + .summary(summary) + .shortSummary(shortSummary) + .company(company) + .logoUrl(logoUrl) + .thumbnailUrl(thumbnailUrl) + .publishedAt(publishedAt) + .crawledAt(LocalDateTime.now()) + .techBlog(techBlog) + .build(); + + post = postRepository.save(post); + // JSON의 원래 ID를 키로 사용 (PostDocument 매핑을 위해) + postMap.put(originalPostId, post); + } + + return postMap; + } + + private int loadReadPosts(Map userMap, Map postMap) + throws IOException { + List> readPostDtos = readJsonFile("read-posts.json", new TypeReference<>() { + }); + + int count = 0; + + for (Map dto : readPostDtos) { + Long userId = ((Number) dto.get("userId")).longValue(); + Long postId = ((Number) dto.get("postId")).longValue(); + String readAtStr = (String) dto.get("readAt"); + Integer readDurationSeconds = ((Number) dto.get("readDurationSeconds")).intValue(); + + User user = userMap.get(userId); + Post post = postMap.get(postId); + + if (user == null || post == null) { + log.warn("ReadPost 로드 실패: userId={}, postId={}", userId, postId); + continue; + } + + LocalDateTime readAt = LocalDateTime.parse(readAtStr); + + ReadPost readPost = ReadPost.create(user, post, readAt, readDurationSeconds); + readPostRepository.save(readPost); + count++; + } + + return count; + } + + private int loadPostDocuments(Map postMap) throws IOException { + List> docDtos = readJsonFile("post-documents.json", new TypeReference<>() { + }); + + int count = 0; + + for (Map dto : docDtos) { + String id = String.valueOf(dto.get("id")); + Long originalPostId = ((Number) dto.get("postId")).longValue(); + + // JSON의 원래 Post ID를 실제 저장된 Post ID로 매핑 + Post post = postMap.get(originalPostId); + if (post == null) { + log.warn("PostDocument 로드 중 Post를 찾을 수 없음: originalPostId={}", originalPostId); + continue; + } + Long actualPostId = post.getId(); + + String title = (String) dto.get("title"); + String summary = (String) dto.get("summary"); 
+ String shortSummary = (String) dto.get("shortSummary"); + String company = (String) dto.get("company"); + String url = (String) dto.get("url"); + String logoUrl = (String) dto.get("logoUrl"); + String thumbnailUrl = (String) dto.get("thumbnailUrl"); + String publishedAt = (String) dto.get("publishedAt"); + + // 임베딩 벡터 (List -> List) + List titleEmbedding = convertToFloatList( + (List) dto.get("titleEmbedding")); + List summaryEmbedding = convertToFloatList( + (List) dto.get("summaryEmbedding")); + + // ContentChunk (nested) + List contentChunks = null; + if (dto.get("contentChunks") != null) { + List> chunkDtos = + (List>) dto.get("contentChunks"); + contentChunks = chunkDtos.stream() + .map(this::convertToContentChunk) + .collect(Collectors.toList()); + } + + PostDocument postDoc = PostDocument.builder() + .id(id) + .postId(actualPostId) // 실제 저장된 Post ID 사용 + .title(title) + .summary(summary) + .shortSummary(shortSummary) + .company(company) + .url(url) + .logoUrl(logoUrl) + .thumbnailUrl(thumbnailUrl) + .publishedAtString(publishedAt) + .titleEmbedding(titleEmbedding) + .summaryEmbedding(summaryEmbedding) + .contentChunks(contentChunks) + .build(); + + postDocumentRepository.save(postDoc); + count++; + } + + return count; + } + + private int loadUserProfiles(Map userMap) throws IOException { + List> profileDtos = readJsonFile("user-profiles.json", new TypeReference<>() { + }); + + int count = 0; + + for (Map dto : profileDtos) { + Long originalUserId = ((Number) dto.get("userId")).longValue(); + String profileText = (String) dto.get("profileText"); + List interests = (List) dto.get("interests"); + + // JSON의 원래 User ID를 실제 DB User ID로 매핑 + User user = userMap.get(originalUserId); + if (user == null) { + log.warn("UserProfile 로드 실패: 사용자를 찾을 수 없음 (원본 User ID={})", originalUserId); + continue; + } + Long actualUserId = user.getId(); + + // 임베딩 벡터 (List -> float[]) + List vectorList = (List) dto.get("profileVector"); + float[] profileVector = null; + if 
(vectorList != null) { + profileVector = new float[vectorList.size()]; + for (int i = 0; i < vectorList.size(); i++) { + profileVector[i] = vectorList.get(i).floatValue(); + } + } + + UserProfileDocument profile = UserProfileDocument.builder() + .userId(actualUserId) + .profileText(profileText) + .profileVector(profileVector) + .interests(interests) + .build(); + + userProfileDocumentRepository.save(profile); + count++; + } + + return count; + } + + private ContentChunk convertToContentChunk(Map dto) { + Integer chunkOrder = dto.get("chunkOrder") != null ? + ((Number) dto.get("chunkOrder")).intValue() : null; + String chunkText = (String) dto.get("chunkText"); + List embedding = convertToFloatList((List) dto.get("embedding")); + + return ContentChunk.builder() + .chunkOrder(chunkOrder) + .chunkText(chunkText) + .embedding(embedding) + .build(); + } + + private List convertToFloatList(List numbers) { + if (numbers == null) return null; + return numbers.stream() + .map(Number::floatValue) + .collect(Collectors.toList()); + } + + /** + * Ground Truth 데이터 로드 (Post ID를 실제 DB ID로 매핑) + * JSON 구조: { "userId": { "postId": relevanceScore, ... }, ... 
} + * + * @param userMap JSON의 원래 User ID -> 실제 저장된 User 매핑 + * @param postMap JSON의 원래 Post ID -> 실제 저장된 Post 매핑 + * @return Map<실제 사용자 DB ID, Map<실제 게시글 DB ID, 관련도점수>> + */ + private Map> loadGroundTruth( + Map userMap, + Map postMap) throws IOException { + // JSON에서 String 키로 읽어서 Long으로 변환 + Map> rawData = readJsonFile("ground-truth.json", new TypeReference<>() { + }); + + Map> groundTruth = new HashMap<>(); + int mappedCount = 0; + int skippedCount = 0; + + for (Map.Entry> userEntry : rawData.entrySet()) { + Long originalUserId = Long.parseLong(userEntry.getKey()); + + // JSON의 User ID -> 실제 DB User ID 매핑 + User user = userMap.get(originalUserId); + if (user == null) { + log.warn("Ground Truth 로드 실패: 사용자를 찾을 수 없음 (원본 User ID={})", originalUserId); + skippedCount++; + continue; + } + Long actualUserId = user.getId(); + + Map postScores = new HashMap<>(); + + for (Map.Entry postEntry : userEntry.getValue().entrySet()) { + Long originalPostId = Long.parseLong(postEntry.getKey()); + Integer relevanceScore = postEntry.getValue(); + + // JSON의 Post ID -> 실제 DB Post ID 매핑 + Post post = postMap.get(originalPostId); + if (post == null) { + log.debug("Ground Truth 매핑 실패: Post를 찾을 수 없음 (원본 Post ID={})", originalPostId); + skippedCount++; + continue; + } + Long actualPostId = post.getId(); + + postScores.put(actualPostId, relevanceScore); + mappedCount++; + } + + if (!postScores.isEmpty()) { + groundTruth.put(actualUserId, postScores); + } + } + + log.info("✓ Ground Truth 매핑 완료: {} 개 (스킵: {} 개)", mappedCount, skippedCount); + return groundTruth; + } + + private T readJsonFile(String filename, TypeReference typeRef) throws IOException { + ClassPathResource resource = new ClassPathResource(FIXTURE_PATH + filename); + return objectMapper.readValue(resource.getInputStream(), typeRef); + } +} \ No newline at end of file diff --git a/src/test/resources/application-integrationtest.yml b/src/test/resources/application-integrationtest.yml index 96f2d1f..c2e3c7c 100644 --- 
a/src/test/resources/application-integrationtest.yml +++ b/src/test/resources/application-integrationtest.yml @@ -94,12 +94,12 @@ recommendation: knn-search-size: 100 num-candidates: 200 mmr-final-size: 30 - lambda: 0.3 + lambda: 0.95 active-user-hours: 24 embedding-weights: - title: 0.2 - summary: 0.2 - content: 0.6 + title: 0.5 + summary: 0.5 + content: 0.0 time-decay: days-7: 1.3 days-30: 1.0