diff --git a/src/main/java/com/sofa/linkiving/domain/chat/ai/AnswerClient.java b/src/main/java/com/sofa/linkiving/domain/chat/ai/AnswerClient.java new file mode 100644 index 00000000..b0d3768e --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/ai/AnswerClient.java @@ -0,0 +1,8 @@ +package com.sofa.linkiving.domain.chat.ai; + +import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq; +import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes; + +public interface AnswerClient { + RagAnswerRes generateAnswer(RagAnswerReq request); +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/ai/MockAnswerClient.java b/src/main/java/com/sofa/linkiving/domain/chat/ai/MockAnswerClient.java new file mode 100644 index 00000000..94bfd50a --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/ai/MockAnswerClient.java @@ -0,0 +1,36 @@ +package com.sofa.linkiving.domain.chat.ai; + +import java.util.List; + +import org.springframework.context.annotation.Primary; +import org.springframework.stereotype.Component; + +import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq; +import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes; + +import lombok.extern.slf4j.Slf4j; + +@Slf4j +@Component +@Primary +public class MockAnswerClient implements AnswerClient { + + @Override + public RagAnswerRes generateAnswer(RagAnswerReq request) { + log.info("[Mock AI Request] User: {}, Question: {}, Mode: {}, HistoryCnt: {}", + request.userId(), request.question(), request.mode(), request.history().size()); + + return new RagAnswerRes( + "임시 답변", + List.of("3", "4"), + List.of( + new RagAnswerRes.ReasoningStep( + "임시 답변 스탭", + List.of("3", "4") + ) + ), + List.of("3", "4"), + false + ); + } +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatApi.java b/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatApi.java index 1ae07939..7d58ed92 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatApi.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatApi.java @@ -2,6 +2,8 @@ import org.springframework.validation.annotation.Validated; +import com.sofa.linkiving.domain.chat.dto.request.AnswerCancelReq; +import com.sofa.linkiving.domain.chat.dto.request.AnswerReq; import com.sofa.linkiving.domain.chat.dto.request.CreateChatReq; import com.sofa.linkiving.domain.chat.dto.response.ChatsRes; import com.sofa.linkiving.domain.chat.dto.response.CreateChatRes; @@ -20,18 +22,45 @@ @Tag(name = "Chat", description = """ AI 채팅 통합 명세 (HTTP + WebSocket) - ### 📡 1. WebSocket 연결 정보 (필수) - 답변을 실시간으로 수신하기 위해 **반드시 소켓 연결 및 구독**이 선행되어야 합니다. - + ### 📡 1. WebSocket 연결 정보 * **Socket Endpoint:** `ws://{domain}/ws/chat` - * **Subscribe Path:** `/topic/chat/{chatId}` - * **Auth Header:** `Authorization: Bearer {accessToken}` (CONNECT 프레임 헤더) + * **Subscribe Path:** `/user/queue/chat` (전역 구독) + ### 🚀 2. 동작 흐름 - 1. **소켓 연결:** 프론트엔드에서 WebSocket 연결 및 `/topic/chat/{chatId}` 구독 - 2. **질문 전송:** `/app/send/{chatId}` (STOMP)로 질문 전송 - 3. **답변 수신:** 소켓 구독 채널로 토큰 단위 답변 스트리밍 (`String` 데이터) - 4. **완료:** `END_OF_STREAM` 메시지 수신 시 스트리밍 종료 + 1. **소켓 연결:** 로그인 직후 `/user/queue/chat` 구독 + 2. **질문 전송:** `/send` 로 요청 전송 + - Body: `{ "chatId": 1, "message": "질문" }` + 3. **답변 수신:** 구독한 경로로 답변 도착 (chatId 포함됨) + **CASE A: 답변 생성 성공 (success: true)** + - AI의 답변과 참고 링크가 포함됩니다. + ```json + { + "success": true, + "chatId": 1, + "messageId": 105, + "content": "질문하신 내용에 대한 AI 답변입니다...", + "step": ["질문 분석", "데이터 검색", "답변 생성"], + "links": [ + { "linkId": 10, "title": "관련 문서 제목", "url": "https://...", "imageUrl": "http://...", "summary": "요약 내용" } + ] + } + ``` + **CASE B: 답변 생성 실패 (success: false)** + - 에러 상황입니다. `content` 필드에 **사용자가 보냈던 원래 질문**이 담겨옵니다. + - 프론트엔드 처리: 이 값을 다시 입력창(Input)에 채워주세요. + ```json + { + "success": false, + "chatId": 1, + "messageId": null, + "content": "내 질문 내용", + "step": null, + "links": null + } + ``` + 4. **답변 취소**: `/cancel` 로 요청 전송 + - Body: `{ "chatId": 1 }` """) public interface ChatApi { @Operation(summary = "채팅 기록 조회", description = "채팅 기록을 최신순으로 조회합니다. 무한 스크롤 방식으로 제공됩니다.") @@ -58,8 +87,7 @@ BaseResponse createChat( @Operation(summary = "링크 삭제", description = "해당 링크방과 채팅 기록을 전부 Hard Delete 진행합니다.") BaseResponse deleteChat(Member member, Long chatId); - void sendMessage(@Parameter(description = "채팅방 Id", required = true) Long chatId, - @Parameter(description = "사용자 질문 내용", required = true) String message, Member member); + void sendMessage(AnswerReq req, Member member); - void cancelMessage(@Parameter(description = "채팅방 Id", required = true) Long chatId, Member member); + void cancelMessage(AnswerCancelReq req, Member member); } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatController.java b/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatController.java index b3ccb0b1..9ebf6f6d 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatController.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatController.java @@ -1,6 +1,5 @@ package com.sofa.linkiving.domain.chat.controller; -import org.springframework.messaging.handler.annotation.DestinationVariable; import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.Payload; import org.springframework.web.bind.annotation.DeleteMapping; @@ -12,6 +11,8 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import com.sofa.linkiving.domain.chat.dto.request.AnswerCancelReq; +import com.sofa.linkiving.domain.chat.dto.request.AnswerReq; import com.sofa.linkiving.domain.chat.dto.request.CreateChatReq; import com.sofa.linkiving.domain.chat.dto.response.ChatsRes; import com.sofa.linkiving.domain.chat.dto.response.CreateChatRes; @@ -51,15 +52,15 @@ public BaseResponse deleteChat(@AuthMember Member member, @PathVariable } @Override - @MessageMapping("/send/{chatId}") - public void sendMessage(@DestinationVariable Long chatId, @Payload String message, @AuthMember Member member) { - chatFacade.generateAnswer(chatId, member, message); + @MessageMapping("/send") + public void sendMessage(@Payload AnswerReq req, @AuthMember Member member) { + chatFacade.generateAnswer(req.chatId(), member, req.message()); } @Override - @MessageMapping("/cancel/{chatId}") - public void cancelMessage(@DestinationVariable Long chatId, @AuthMember Member member) { - chatFacade.cancelAnswer(chatId, member); + @MessageMapping("/cancel") + public void cancelMessage(@Payload AnswerCancelReq req, @AuthMember Member member) { + chatFacade.cancelAnswer(req.chatId(), member); } @Override diff --git a/src/main/java/com/sofa/linkiving/domain/chat/controller/MockAiController.java b/src/main/java/com/sofa/linkiving/domain/chat/controller/MockAiController.java deleted file mode 100644 index c4771a60..00000000 --- a/src/main/java/com/sofa/linkiving/domain/chat/controller/MockAiController.java +++ /dev/null @@ -1,33 +0,0 @@ -package com.sofa.linkiving.domain.chat.controller; - -import java.time.Duration; -import java.util.Map; - -import org.springframework.http.MediaType; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; - -import reactor.core.publisher.Flux; - -@RestController -@RequestMapping("/mock/ai") -public class MockAiController { - - @PostMapping(value = "/generate", produces = MediaType.APPLICATION_NDJSON_VALUE) // 또는 TEXT_EVENT_STREAM_VALUE - public Flux generateAnswer(@RequestBody Map request) { - String userPrompt = request.get("prompt"); - - String fakeResponse = """ - 안녕하세요! 저는 임시 AI 봇입니다. 🤖 - 현재 AI 서버가 구축되지 않아서 테스트용 답변을 드리고 있어요. - 질문하신 내용인 "%s"에 대해 답변을 생성하는 척 하고 있습니다. - 취소 기능을 테스트하시려면 지금 바로 취소 버튼을 눌러보세요! - 타이핑 효과를 위해 천천히 답변을 보내고 있습니다... - """.formatted(userPrompt); - - return Flux.fromArray(fakeResponse.split("")) - .delayElements(Duration.ofMillis(100)); - } -} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/dto/request/AnswerCancelReq.java b/src/main/java/com/sofa/linkiving/domain/chat/dto/request/AnswerCancelReq.java new file mode 100644 index 00000000..0464c5cd --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/dto/request/AnswerCancelReq.java @@ -0,0 +1,9 @@ +package com.sofa.linkiving.domain.chat.dto.request; + +import io.swagger.v3.oas.annotations.media.Schema; + +public record AnswerCancelReq( + @Schema(description = "채팅방 ID") + Long chatId +) { +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/dto/request/AnswerReq.java b/src/main/java/com/sofa/linkiving/domain/chat/dto/request/AnswerReq.java new file mode 100644 index 00000000..4eb54b42 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/dto/request/AnswerReq.java @@ -0,0 +1,11 @@ +package com.sofa.linkiving.domain.chat.dto.request; + +import io.swagger.v3.oas.annotations.media.Schema; + +public record AnswerReq( + @Schema(description = "채팅방 ID") + Long chatId, + @Schema(description = "유저 질문 내용") + String message +) { +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/dto/request/RagAnswerReq.java b/src/main/java/com/sofa/linkiving/domain/chat/dto/request/RagAnswerReq.java new file mode 100644 index 00000000..110830bf --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/dto/request/RagAnswerReq.java @@ -0,0 +1,13 @@ +package com.sofa.linkiving.domain.chat.dto.request; + +import java.util.List; + +import com.sofa.linkiving.domain.chat.enums.Mode; + +public record RagAnswerReq( + Long userId, + String question, + List history, + Mode mode +) { +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/dto/response/AnswerRes.java b/src/main/java/com/sofa/linkiving/domain/chat/dto/response/AnswerRes.java new file mode 100644 index 00000000..571717d0 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/dto/response/AnswerRes.java @@ -0,0 +1,46 @@ +package com.sofa.linkiving.domain.chat.dto.response; + +import java.util.List; + +import com.sofa.linkiving.domain.chat.entity.Message; +import com.sofa.linkiving.domain.link.dto.internal.LinkDto; +import com.sofa.linkiving.domain.link.dto.response.LinkCardRes; + +import io.swagger.v3.oas.annotations.media.Schema; + +public record AnswerRes( + @Schema(description = "성공 여부") + Boolean success, + @Schema(description = "채팅방 ID") + Long chatId, + @Schema(description = "메세지 ID") + Long messageId, + @Schema(description = "답변 내용") + String content, + @Schema(description = "스텝 목록") + List step, + @Schema(description = "첨부된 링크 목록") + List links +) { + public static AnswerRes of(Long chatId, Message message, List step, List linkDtos) { + return new AnswerRes( + true, + chatId, + message.getId(), + message.getContent(), + step, + linkDtos.stream().map(LinkCardRes::from).toList() + ); + } + + public static AnswerRes error(Long chatId, String content) { + return new AnswerRes( + false, + chatId, + null, + content, + null, + null + ); + } +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/dto/response/RagAnswerRes.java b/src/main/java/com/sofa/linkiving/domain/chat/dto/response/RagAnswerRes.java new file mode 100644 index 00000000..9a508adb --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/dto/response/RagAnswerRes.java @@ -0,0 +1,17 @@ +package com.sofa.linkiving.domain.chat.dto.response; + +import java.util.List; + +public record RagAnswerRes( + String answer, + List linkIds, + List reasoningSteps, + List relatedLinks, + boolean isFallback +) { + public record ReasoningStep( + String step, + List linkIds + ) { + } +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/enums/Mode.java b/src/main/java/com/sofa/linkiving/domain/chat/enums/Mode.java new file mode 100644 index 00000000..0cb24512 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/enums/Mode.java @@ -0,0 +1,31 @@ +package com.sofa.linkiving.domain.chat.enums; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +@Getter +@RequiredArgsConstructor +public enum Mode { + DETAILED("detailed"), + CONCISE("concise"); + + private final String value; + + @JsonCreator + public static Mode from(String value) { + for (Mode mode : Mode.values()) { + if (mode.getValue().equalsIgnoreCase(value)) { + return mode; + } + } + return DETAILED; + } + + @JsonValue + public String getValue() { + return value; + } +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/facade/ChatFacade.java b/src/main/java/com/sofa/linkiving/domain/chat/facade/ChatFacade.java index d80bccdb..1f87668d 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/facade/ChatFacade.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/facade/ChatFacade.java @@ -1,31 +1,41 @@ package com.sofa.linkiving.domain.chat.facade; import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import com.sofa.linkiving.domain.chat.ai.AiTitleClient; import com.sofa.linkiving.domain.chat.dto.internal.MessagesDto; +import com.sofa.linkiving.domain.chat.dto.response.AnswerRes; import com.sofa.linkiving.domain.chat.dto.response.ChatsRes; import com.sofa.linkiving.domain.chat.dto.response.CreateChatRes; import com.sofa.linkiving.domain.chat.dto.response.MessagesRes; import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.manager.TaskManager; import com.sofa.linkiving.domain.chat.service.ChatService; import com.sofa.linkiving.domain.chat.service.FeedbackService; import com.sofa.linkiving.domain.chat.service.MessageService; +import com.sofa.linkiving.domain.chat.service.RagChatService; import com.sofa.linkiving.domain.member.entity.Member; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; @Service @RequiredArgsConstructor @Transactional(readOnly = true) +@Slf4j public class ChatFacade { private final ChatService chatService; private final MessageService messageService; private final FeedbackService feedbackService; private final AiTitleClient aiTitleClient; + private final RagChatService ragChatService; + private final TaskManager taskManager; + private final SimpMessagingTemplate messagingTemplate; public MessagesRes getMessages(Member member, Long chatId, Long lastId, int size) { Chat chat = chatService.getChat(chatId, member); @@ -57,12 +67,37 @@ public void deleteChat(Member member, Long chatId) { @Transactional public void generateAnswer(Long chatId, Member member, String message) { - Chat chat = chatService.getChat(chatId, member); - messageService.generateAnswer(chat, message); + + CompletableFuture task = ragChatService.generateAnswer(chatId, member, message); + + taskManager.put(chatId, task); + + task.whenComplete((result, ex) -> { + taskManager.remove(chatId); + + if (task.isCancelled() || ex != null) { + sendNotification(chatId, member, AnswerRes.error(chatId, message)); + return; + } + + if (result != null) { + sendNotification(chatId, member, result); + } + }); + } + + private void sendNotification(Long chatId, Member member, AnswerRes res) { + messagingTemplate.convertAndSendToUser( + member.getEmail(), + "/queue/chat", + res + ); } public void cancelAnswer(Long chatId, Member member) { - Chat chat = chatService.getChat(chatId, member); - messageService.cancelAnswer(chat); + if (chatService.existsChat(member, chatId)) { + log.info("Cancelling answer for chat {}", chatId); + taskManager.cancel(chatId); + } } } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManager.java b/src/main/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManager.java deleted file mode 100644 index 27d83a85..00000000 --- a/src/main/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManager.java +++ /dev/null @@ -1,39 +0,0 @@ -package com.sofa.linkiving.domain.chat.manager; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -import org.springframework.stereotype.Component; - -import reactor.core.Disposable; - -@Component -public class SubscriptionManager { - - private final Map activeSubscriptions = new ConcurrentHashMap<>(); - - /** - * 구독 추가 (기존 작업이 있다면 취소 후 등록) - */ - public void add(String key, Disposable subscription) { - cancel(key); // 안전하게 기존 작업 정리 - activeSubscriptions.put(key, subscription); - } - - /** - * 구독 취소 및 자원 해제 - */ - public void cancel(String key) { - Disposable subscription = activeSubscriptions.remove(key); - if (subscription != null && !subscription.isDisposed()) { - subscription.dispose(); - } - } - - /** - * 완료된 구독 제거 (자원 해제 없이 Map에서만 삭제) - */ - public void remove(String key) { - activeSubscriptions.remove(key); - } -} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/manager/TaskManager.java b/src/main/java/com/sofa/linkiving/domain/chat/manager/TaskManager.java new file mode 100644 index 00000000..170c52e1 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/manager/TaskManager.java @@ -0,0 +1,27 @@ +package com.sofa.linkiving.domain.chat.manager; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Future; + +import org.springframework.stereotype.Component; + +@Component +public class TaskManager { + private final Map> activeTasks = new ConcurrentHashMap<>(); + + public void put(Long chatId, Future task) { + activeTasks.put(chatId, task); + } + + public void cancel(Long chatId) { + Future task = activeTasks.remove(chatId); + if (task != null && !task.isDone()) { + task.cancel(true); + } + } + + public void remove(Long chatId) { + activeTasks.remove(chatId); + } +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/repository/ChatRepository.java b/src/main/java/com/sofa/linkiving/domain/chat/repository/ChatRepository.java index 40bc94e7..0842b0aa 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/repository/ChatRepository.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/repository/ChatRepository.java @@ -24,4 +24,6 @@ ORDER BY MAX(m.createdAt) DESC List findAllByMemberOrderByLastMessageDesc(@Param("member") Member member); Optional findByIdAndMember(Long id, Member member); + + Boolean existsByIdAndMember(Long id, Member member); } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/repository/MessageRepository.java b/src/main/java/com/sofa/linkiving/domain/chat/repository/MessageRepository.java index 34c6472d..e75d29ec 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/repository/MessageRepository.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/repository/MessageRepository.java @@ -32,4 +32,6 @@ List findAllByChatAndCursor( ); List findAllByChat(Chat chat); + + List findTop7ByChatAndIdLessThanOrderByIdDesc(Chat chat, Long id); } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/ChatQueryService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/ChatQueryService.java index 3de037e0..502db2f3 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/ChatQueryService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/ChatQueryService.java @@ -26,4 +26,8 @@ public Chat findChat(Long chatId, Member member) { public List findAllOrderByLastMessageDesc(Member member) { return chatRepository.findAllByMemberOrderByLastMessageDesc(member); } + + public boolean existsByIdAndMember(Member member, Long chatId) { + return chatRepository.existsByIdAndMember(chatId, member); + } } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/ChatService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/ChatService.java index 7bb31b45..8d764435 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/ChatService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/ChatService.java @@ -30,4 +30,8 @@ public List getChats(Member member) { public Chat createChat(String title, Member member) { return chatCommandService.saveChat(title, member); } + + public boolean existsChat(Member member, Long chatId) { + return chatQueryService.existsByIdAndMember(member, chatId); + } } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageCommandService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageCommandService.java index 5107e4a0..19b4624c 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageCommandService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageCommandService.java @@ -1,10 +1,14 @@ package com.sofa.linkiving.domain.chat.service; +import java.util.List; + import org.springframework.stereotype.Service; import com.sofa.linkiving.domain.chat.entity.Chat; import com.sofa.linkiving.domain.chat.entity.Message; +import com.sofa.linkiving.domain.chat.enums.Type; import com.sofa.linkiving.domain.chat.repository.MessageRepository; +import com.sofa.linkiving.domain.link.entity.Link; import lombok.RequiredArgsConstructor; @@ -20,4 +24,25 @@ public void deleteAllByChat(Chat chat) { public Message saveMessage(Message message) { return messageRepository.save(message); } + + public Message saveUserMessage(Chat chat, String content) { + Message message = Message.builder() + .chat(chat) + .type(Type.USER) + .content(content) + .build(); + + return messageRepository.save(message); + } + + public Message saveAiMessage(Chat chat, String content, List links) { + Message message = Message.builder() + .chat(chat) + .type(Type.AI) + .content(content) + .links(links) + .build(); + + return messageRepository.save(message); + } } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageQueryService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageQueryService.java index 1373eaa1..93de9f0d 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageQueryService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageQueryService.java @@ -30,4 +30,8 @@ public Slice findAllByChatAndCursor(Chat chat, Long lastId, int size) { return new SliceImpl<>(messages, pageRequest, hasNext); } + + public List findTop7ByChatIdAndIdLessThanOrderByIdDesc(Long id, Chat chat) { + return messageRepository.findTop7ByChatAndIdLessThanOrderByIdDesc(chat, id); + } } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageService.java index 8ef29b84..59a8d95e 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageService.java @@ -2,26 +2,20 @@ import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import org.springframework.data.domain.Slice; -import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.stereotype.Service; -import org.springframework.web.reactive.function.client.WebClient; import com.sofa.linkiving.domain.chat.dto.internal.MessageDto; import com.sofa.linkiving.domain.chat.dto.internal.MessagesDto; import com.sofa.linkiving.domain.chat.entity.Chat; import com.sofa.linkiving.domain.chat.entity.Message; -import com.sofa.linkiving.domain.chat.enums.Type; -import com.sofa.linkiving.domain.chat.manager.SubscriptionManager; import com.sofa.linkiving.domain.link.dto.internal.LinkDto; import com.sofa.linkiving.domain.link.entity.Link; import com.sofa.linkiving.domain.link.entity.Summary; import com.sofa.linkiving.domain.link.service.SummaryQueryService; import lombok.RequiredArgsConstructor; -import reactor.core.Disposable; @Service @RequiredArgsConstructor @@ -30,71 +24,6 @@ public class MessageService { private final MessageQueryService messageQueryService; private final SummaryQueryService summaryQueryService; - private final SimpMessagingTemplate messagingTemplate; - private final SubscriptionManager subscriptionManager; - - private final WebClient webClient = WebClient.create("http://localhost:8080/mock/ai"); - private final Map messageBuffers = new ConcurrentHashMap<>(); - - public void generateAnswer(Chat chat, String userMessage) { - - String roomId = chat.getId().toString(); - - if (messageBuffers.containsKey(roomId)) { - return; - } - - messageBuffers.put(roomId, new StringBuilder()); - - Disposable subscription = webClient.post() - .uri("/generate") - .bodyValue(Map.of("prompt", userMessage)) - .retrieve() - .bodyToFlux(String.class) - .doOnComplete(() -> { - String fullAnswer = messageBuffers.remove(roomId).toString(); - - saveMessage(chat, Type.USER, userMessage); - saveMessage(chat, Type.AI, fullAnswer); - - subscriptionManager.remove(roomId); - messagingTemplate.convertAndSend("/topic/chat/" + roomId, "END_OF_STREAM"); - }) - .doOnError(e -> { - subscriptionManager.remove(roomId); - messagingTemplate.convertAndSend("/topic/chat/" + roomId, "ERROR: " + e.getMessage()); - }) - .subscribe(token -> { - StringBuilder buffer = messageBuffers.get(roomId); - if (buffer != null) { - buffer.append(token); - } - - messagingTemplate.convertAndSend("/topic/chat/" + roomId, token); - }); - - subscriptionManager.add(roomId, subscription); - } - - public void cancelAnswer(Chat chat) { - String roomId = chat.getId().toString(); - - subscriptionManager.cancel(roomId); - messageBuffers.remove(roomId); - - messagingTemplate.convertAndSend("/topic/chat/" + roomId, "GENERATION_CANCELLED"); - } - - private void saveMessage(Chat chat, Type type, String content) { - Message message = Message.builder() - .chat(chat) - .type(type) - .content(content) - .build(); - - messageCommandService.saveMessage(message); - } - public void deleteAll(Chat chat) { messageCommandService.deleteAllByChat(chat); } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/RagChatService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/RagChatService.java new file mode 100644 index 00000000..7cffce3e --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/RagChatService.java @@ -0,0 +1,95 @@ +package com.sofa.linkiving.domain.chat.service; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; + +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; + +import com.sofa.linkiving.domain.chat.ai.AnswerClient; +import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq; +import com.sofa.linkiving.domain.chat.dto.response.AnswerRes; +import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes; +import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.entity.Message; +import com.sofa.linkiving.domain.chat.enums.Mode; +import com.sofa.linkiving.domain.link.dto.internal.LinkDto; +import com.sofa.linkiving.domain.link.entity.Link; +import com.sofa.linkiving.domain.link.service.LinkQueryService; +import com.sofa.linkiving.domain.member.entity.Member; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +@Service +@RequiredArgsConstructor +public class RagChatService { + private final AnswerClient answerClient; + private final MessageCommandService messageCommandService; + private final MessageQueryService messageQueryService; + private final LinkQueryService linkQueryService; + private final ChatQueryService chatQueryService; + + @Async + @Transactional(propagation = Propagation.NOT_SUPPORTED) + public CompletableFuture generateAnswer(Long chatId, Member member, String userMessage) { + + Chat chat = chatQueryService.findChat(chatId, member); + + Message question = messageCommandService.saveUserMessage(chat, userMessage); + List pastMessages = messageQueryService.findTop7ByChatIdAndIdLessThanOrderByIdDesc( + question.getId(), chat); + + List history = new ArrayList<>(pastMessages.stream() + .map(Message::getContent) + .toList()); + Collections.reverse(history); + + RagAnswerReq request = new RagAnswerReq( + member.getId(), + userMessage, + history, + Mode.DETAILED + ); + + RagAnswerRes res = answerClient.generateAnswer(request); + + String fullAnswer = res.answer(); + + List linkIds = parseLinkIds(res.linkIds()); + List linkDtos = linkQueryService.findAllByIdInWithSummary(linkIds, member); + List links = linkDtos.stream().map(LinkDto::link).toList(); + + List steps = res.reasoningSteps().stream().map(RagAnswerRes.ReasoningStep::step).toList(); + + Message answer = messageCommandService.saveAiMessage(chat, fullAnswer, links); + + AnswerRes payload = AnswerRes.of(chat.getId(), answer, steps, linkDtos); + + return CompletableFuture.completedFuture(payload); + + } + + private List parseLinkIds(List linkIds) { + if (linkIds == null || linkIds.isEmpty()) { + return Collections.emptyList(); + } + return linkIds.stream() + .map(id -> { + try { + return Long.parseLong(id.trim()); + } catch (NumberFormatException e) { + log.warn("AI returned invalid linkId: {}", id); + return null; + } + }) + .filter(Objects::nonNull) + .toList(); + } +} diff --git a/src/main/java/com/sofa/linkiving/domain/link/repository/LinkRepository.java b/src/main/java/com/sofa/linkiving/domain/link/repository/LinkRepository.java index ca90c4c2..7c7c45ce 100644 --- a/src/main/java/com/sofa/linkiving/domain/link/repository/LinkRepository.java +++ b/src/main/java/com/sofa/linkiving/domain/link/repository/LinkRepository.java @@ -52,4 +52,17 @@ List findAllByMemberWithSummaryAndCursorAndIsDeleteFalse( @Param("lastId") Long lastId, Pageable pageable ); + + @Query(""" + SELECT new com.sofa.linkiving.domain.link.dto.internal.LinkDto(l, s) + FROM Link l + LEFT JOIN Summary s ON s.link = l AND s.selected = true + WHERE l.id IN :linkIds + AND l.member = :member + AND l.isDelete = false + """) + List findAllByMemberAndIdInWithSummaryAndIsDeleteFalse( + @Param("linkIds") List linkIds, + @Param("member") Member member + ); } diff --git a/src/main/java/com/sofa/linkiving/domain/link/service/LinkQueryService.java b/src/main/java/com/sofa/linkiving/domain/link/service/LinkQueryService.java index 598a22ba..de28235a 100644 --- a/src/main/java/com/sofa/linkiving/domain/link/service/LinkQueryService.java +++ b/src/main/java/com/sofa/linkiving/domain/link/service/LinkQueryService.java @@ -35,6 +35,10 @@ public LinkDto findByIdWithSummary(Long linkId, Member member) { .orElseThrow(() -> new BusinessException(LinkErrorCode.LINK_NOT_FOUND)); } + public List findAllByIdInWithSummary(List linkIds, Member member) { + return linkRepository.findAllByMemberAndIdInWithSummaryAndIsDeleteFalse(linkIds, member); + } + public LinksDto findAllByMemberWithSummaryAndCursor(Member member, Long lastId, int size) { PageRequest pageRequest = PageRequest.of(0, size + 1); List linkDtos = linkRepository.findAllByMemberWithSummaryAndCursorAndIsDeleteFalse(member, lastId, diff --git a/src/main/java/com/sofa/linkiving/global/config/WebSocketConfig.java b/src/main/java/com/sofa/linkiving/global/config/WebSocketConfig.java index d107c148..af321eeb 100644 --- a/src/main/java/com/sofa/linkiving/global/config/WebSocketConfig.java +++ b/src/main/java/com/sofa/linkiving/global/config/WebSocketConfig.java @@ -6,6 +6,7 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.scheduling.annotation.EnableAsync; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; @@ -17,6 +18,7 @@ @Configuration @EnableWebSocketMessageBroker +@EnableAsync @RequiredArgsConstructor public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { @@ -25,8 +27,10 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { @Override public void configureMessageBroker(MessageBrokerRegistry config) { - config.enableSimpleBroker("/topic/chat"); + config.enableSimpleBroker("/topic", "/queue"); config.setApplicationDestinationPrefixes("/ws/chat"); + + config.setUserDestinationPrefix("/user"); } @Override diff --git a/src/test/java/com/sofa/linkiving/domain/chat/ai/MockAnswerClientTest.java b/src/test/java/com/sofa/linkiving/domain/chat/ai/MockAnswerClientTest.java new file mode 100644 index 00000000..43b52ed4 --- /dev/null +++ b/src/test/java/com/sofa/linkiving/domain/chat/ai/MockAnswerClientTest.java @@ -0,0 +1,56 @@ +package com.sofa.linkiving.domain.chat.ai; + +import static org.assertj.core.api.Assertions.*; + +import java.util.Collections; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.junit.jupiter.MockitoExtension; + +import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq; +import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes; +import com.sofa.linkiving.domain.chat.enums.Mode; + +@ExtendWith(MockitoExtension.class) +@DisplayName("MockAnswerClient 단위 테스트") +public class MockAnswerClientTest { + @InjectMocks + private MockAnswerClient mockAnswerClient; + + @Test + @DisplayName("입력값과 관계없이 항상 고정된 Gemini 관련 답변과 메타데이터를 반환한다") + void shouldReturnFixedAnswer() { + RagAnswerReq req = new RagAnswerReq( + 1L, + "테스트 질문", + Collections.emptyList(), + Mode.DETAILED + ); + + // when + RagAnswerRes res = mockAnswerClient.generateAnswer(req); + + // then + assertThat(res).isNotNull(); + + assertThat(res.answer()).contains("임시 답변"); + + assertThat(res.linkIds()) + .hasSize(2) + .containsExactly("3", "4"); + + assertThat(res.reasoningSteps()).hasSize(1); + RagAnswerRes.ReasoningStep step = res.reasoningSteps().get(0); + assertThat(step.step()).contains("임시 답변 스탭"); + assertThat(step.linkIds()).containsExactly("3", "4"); + + assertThat(res.relatedLinks()) + .hasSize(2) + .containsExactly("3", "4"); + + assertThat(res.isFallback()).isFalse(); + } +} diff --git a/src/test/java/com/sofa/linkiving/domain/chat/facade/ChatFacadeTest.java b/src/test/java/com/sofa/linkiving/domain/chat/facade/ChatFacadeTest.java index 9b06f688..c0c4adc5 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/facade/ChatFacadeTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/facade/ChatFacadeTest.java @@ -5,6 +5,7 @@ import java.util.Collections; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -12,16 +13,20 @@ import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.messaging.simp.SimpMessagingTemplate; import com.sofa.linkiving.domain.chat.ai.AiTitleClient; import com.sofa.linkiving.domain.chat.dto.internal.MessagesDto; +import com.sofa.linkiving.domain.chat.dto.response.AnswerRes; import com.sofa.linkiving.domain.chat.dto.response.ChatsRes; import com.sofa.linkiving.domain.chat.dto.response.CreateChatRes; import com.sofa.linkiving.domain.chat.dto.response.MessagesRes; import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.manager.TaskManager; import com.sofa.linkiving.domain.chat.service.ChatService; import com.sofa.linkiving.domain.chat.service.FeedbackService; import com.sofa.linkiving.domain.chat.service.MessageService; +import com.sofa.linkiving.domain.chat.service.RagChatService; import com.sofa.linkiving.domain.member.entity.Member; @ExtendWith(MockitoExtension.class) @@ -41,6 +46,15 @@ public class ChatFacadeTest { @Mock private AiTitleClient aiTitleClient; + @Mock + private RagChatService ragChatService; + + @Mock + private TaskManager taskManager; + + @Mock + private SimpMessagingTemplate messagingTemplate; + @Mock private Member member; @@ -140,4 +154,128 @@ void shouldDeleteAllRelatedDataWhenDeleteChat() { // 3. 채팅방 삭제 호출 확인 verify(chatService).delete(chat); } + + @Test + @DisplayName("답변 생성이 성공하면 TaskManager에서 제거하고 성공 알림 전송") + void shouldSendNotificationWhenAnswerGeneratedSuccessfully() { + // given + Long chatId = 1L; + String userMessage = "질문입니다"; + member = mock(Member.class); + given(member.getEmail()).willReturn("test@test.com"); + + CompletableFuture future = new CompletableFuture<>(); + + given(ragChatService.generateAnswer(chatId, member, userMessage)).willReturn(future); + + // when + chatFacade.generateAnswer(chatId, member, userMessage); + + // then + verify(taskManager).put(chatId, future); + + AnswerRes successRes = mock(AnswerRes.class); + future.complete(successRes); + + verify(taskManager).remove(chatId); + + verify(messagingTemplate).convertAndSendToUser( + eq(member.getEmail()), + eq("/queue/chat"), + eq(successRes) + ); + } + + @Test + @DisplayName("답변 생성 중 예외가 발생하면 에러 알림 전송") + void shouldSendErrorNotificationWhenExceptionOccurs() { + // given + Long chatId = 1L; + String userMessage = "질문입니다"; + member = mock(Member.class); + given(member.getEmail()).willReturn("test@test.com"); + + CompletableFuture future = new CompletableFuture<>(); + given(ragChatService.generateAnswer(chatId, member, userMessage)).willReturn(future); + + // when + chatFacade.generateAnswer(chatId, member, userMessage); + + // then + verify(taskManager).put(chatId, future); + + // 2. 비동기 작업 완료 시뮬레이션 (예외 발생) + future.completeExceptionally(new RuntimeException("AI Server Error")); + + // 3. 콜백 실행 후 TaskManager 제거 및 에러 전송 확인 + verify(taskManager).remove(chatId); + + // 에러 발생 시 AnswerRes.error(...)가 전송되어야 함 + verify(messagingTemplate).convertAndSendToUser( + eq(member.getEmail()), + eq("/queue/chat"), + any(AnswerRes.class) // AnswerRes.error() 결과 + ); + } + + @Test + @DisplayName("작업이 취소되면 에러 알림 전송") + void shouldSendErrorNotificationWhenTaskIsCancelled() { + // given + Long chatId = 1L; + String userMessage = "질문입니다"; + member = mock(Member.class); + given(member.getEmail()).willReturn("test@test.com"); + + CompletableFuture future = new CompletableFuture<>(); + given(ragChatService.generateAnswer(chatId, member, userMessage)).willReturn(future); + + // when + chatFacade.generateAnswer(chatId, member, userMessage); + + // 2. 작업 취소 시뮬레이션 + future.cancel(true); + + // then + verify(taskManager).remove(chatId); + + // 취소 상태일 때도 에러 메시지 전송 로직을 타는지 확인 + verify(messagingTemplate).convertAndSendToUser( + eq(member.getEmail()), + eq("/queue/chat"), + any(AnswerRes.class) + ); + } + + @Test + @DisplayName("존재하는 채팅방인 경우 TaskManager에 취소 요청") + void shouldCancelTaskWhenChatExists() { + // given + Long chatId = 1L; + member = mock(Member.class); + + given(chatService.existsChat(member, chatId)).willReturn(true); + + // when + chatFacade.cancelAnswer(chatId, member); + + // then + verify(taskManager).cancel(chatId); + } + + @Test + @DisplayName("존재하지 않는 채팅방인 경우 아무 작업도 하지 않음") + void shouldNotCancelTaskWhenChatDoesNotExist() { + // given + Long chatId = 1L; + member = mock(Member.class); + + given(chatService.existsChat(member, chatId)).willReturn(false); + + // when + chatFacade.cancelAnswer(chatId, member); + + // then + verify(taskManager, never()).cancel(anyLong()); + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/integration/WebSocketChatIntegrationTest.java b/src/test/java/com/sofa/linkiving/domain/chat/integration/WebSocketChatIntegrationTest.java index 4a75433b..c6e4200e 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/integration/WebSocketChatIntegrationTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/integration/WebSocketChatIntegrationTest.java @@ -1,30 +1,33 @@ package com.sofa.linkiving.domain.chat.integration; +import static java.util.concurrent.TimeUnit.*; import static org.assertj.core.api.Assertions.*; +import static org.mockito.BDDMockito.*; import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; -import java.util.UUID; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; -import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.web.server.LocalServerPort; -import org.springframework.messaging.converter.StringMessageConverter; +import org.springframework.messaging.converter.MappingJackson2MessageConverter; import org.springframework.messaging.simp.stomp.StompFrameHandler; import org.springframework.messaging.simp.stomp.StompHeaders; import org.springframework.messaging.simp.stomp.StompSession; import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter; +import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.bean.override.mockito.MockitoBean; -import org.springframework.test.util.ReflectionTestUtils; -import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.client.standard.StandardWebSocketClient; import org.springframework.web.socket.messaging.WebSocketStompClient; @@ -32,153 +35,192 @@ import org.springframework.web.socket.sockjs.client.Transport; import org.springframework.web.socket.sockjs.client.WebSocketTransport; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.sofa.linkiving.domain.chat.ai.AnswerClient; +import com.sofa.linkiving.domain.chat.dto.request.AnswerCancelReq; +import com.sofa.linkiving.domain.chat.dto.request.AnswerReq; +import com.sofa.linkiving.domain.chat.dto.response.AnswerRes; +import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes; import com.sofa.linkiving.domain.chat.entity.Chat; import com.sofa.linkiving.domain.chat.repository.ChatRepository; -import com.sofa.linkiving.domain.chat.service.MessageService; +import com.sofa.linkiving.domain.chat.repository.MessageRepository; +import com.sofa.linkiving.domain.link.repository.LinkRepository; +import com.sofa.linkiving.domain.link.repository.SummaryRepository; import com.sofa.linkiving.domain.member.entity.Member; import com.sofa.linkiving.domain.member.repository.MemberRepository; import com.sofa.linkiving.infra.redis.RedisService; import com.sofa.linkiving.security.jwt.JwtTokenProvider; @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD) @ActiveProfiles("test") public class WebSocketChatIntegrationTest { @LocalServerPort private int port; - private WebSocketStompClient stompClient; + @Autowired + private MemberRepository memberRepository; @Autowired - private MessageService messageService; + private ChatRepository chatRepository; @Autowired - private MemberRepository memberRepository; + private MessageRepository messageRepository; @Autowired - private ChatRepository chatRepository; + private LinkRepository linkRepository; + + @Autowired + private SummaryRepository summaryRepository; @MockitoBean private RedisService redisService; @Autowired - private JwtTokenProvider jwtTokenProvider; // 실제 토큰 생성 로직 사용 (또는 MockBean) + private JwtTokenProvider jwtTokenProvider; + + @MockitoBean + private AnswerClient answerClient; - private Chat savedChat; - private String validToken; + private StompSession stompSession; + private BlockingQueue blockingQueue; + private Chat testChat; @BeforeEach - void setUp() { - // 1. WebSocket Client 설정 - StandardWebSocketClient standardWebSocketClient = new StandardWebSocketClient(); - WebSocketTransport webSocketTransport = new WebSocketTransport(standardWebSocketClient); - List transports = List.of(webSocketTransport); - SockJsClient sockJsClient = new SockJsClient(transports); - - stompClient = new WebSocketStompClient(sockJsClient); - stompClient.setMessageConverter(new StringMessageConverter()); - - // 2. MockAiController 연결을 위한 WebClient 주소 조작 (핵심) - String testUrl = "http://localhost:" + port + "/mock/ai"; - WebClient testWebClient = WebClient.create(testUrl); - ReflectionTestUtils.setField(messageService, "webClient", testWebClient); - - // 3. 테스트 데이터 생성 - String uniqueEmail = "socket_" + UUID.randomUUID().toString().substring(0, 8) + "@test.com"; - - Member savedMember = memberRepository.save(Member.builder() - .email(uniqueEmail) + void setUp() throws ExecutionException, InterruptedException, TimeoutException { + messageRepository.deleteAllInBatch(); + chatRepository.deleteAllInBatch(); + summaryRepository.deleteAllInBatch(); + linkRepository.deleteAllInBatch(); + memberRepository.deleteAllInBatch(); + + // 1. 데이터 초기화 + Member testMember = memberRepository.save(Member.builder() + .email("test@test.com") .password("password") .build()); - savedChat = chatRepository.save(Chat.builder() - .member(savedMember) - .title("test") + testChat = chatRepository.save(Chat.builder() + .member(testMember) + .title("테스트 채팅방") .build()); - // 4. 유효한 토큰 생성 (StompHandler 통과용) - validToken = jwtTokenProvider.createAccessToken(savedMember.getEmail()); - } + RagAnswerRes defaultRes = new RagAnswerRes( + "Gemini와 관련된 내용은 두 개의 아티클에 포함돼 있습니다.", + List.of("3", "4"), + List.of(new RagAnswerRes.ReasoningStep("제공된 컨텍스트 중...", List.of("3", "4"))), + List.of("3", "4"), + false + ); + given(answerClient.generateAnswer(any())).willReturn(defaultRes); - @Test - @DisplayName("메시지 전송 시 MockAiController를 통해 스트리밍 답변 수신") - void shouldReceiveStreamingResponseWhenSendMessage() throws Exception { - // given - String wsUrl = String.format("ws://localhost:%d/ws/chat", port); + // 2. STOMP 클라이언트 설정 + WebSocketStompClient stompClient = new WebSocketStompClient(new SockJsClient(createTransportClient())); + + MappingJackson2MessageConverter converter = new MappingJackson2MessageConverter(); + converter.setObjectMapper(new ObjectMapper().registerModule(new JavaTimeModule())); + stompClient.setMessageConverter(converter); + + this.blockingQueue = new LinkedBlockingQueue<>(); + + // 3. WebSocket 연결 + String wsUrl = "ws://localhost:" + port + "/ws/chat"; StompHeaders headers = new StompHeaders(); + + String validToken = jwtTokenProvider.createAccessToken(testMember.getEmail()); headers.add("Authorization", "Bearer " + validToken); - WebSocketHttpHeaders handshakeHeaders = new WebSocketHttpHeaders(); + this.stompSession = stompClient.connectAsync( + wsUrl, + new WebSocketHttpHeaders(), + headers, + new StompSessionHandlerAdapter() { + } + ).get(5, SECONDS); + } - StompSession session = stompClient.connectAsync(wsUrl, handshakeHeaders, headers, - new StompSessionHandlerAdapter() { - }) - .get(5, TimeUnit.SECONDS); + @AfterEach + void tearDown() { + messageRepository.deleteAllInBatch(); + chatRepository.deleteAllInBatch(); + summaryRepository.deleteAllInBatch(); + linkRepository.deleteAllInBatch(); + memberRepository.deleteAllInBatch(); + } - Long chatId = savedChat.getId(); - String userMessage = "테스트 질문"; - BlockingQueue queue = new LinkedBlockingQueue<>(); + private List createTransportClient() { + List transports = new ArrayList<>(); + transports.add(new WebSocketTransport(new StandardWebSocketClient())); + return transports; + } - // when: 구독 (/topic/chat/{chatId}) - session.subscribe("/topic/chat/" + chatId, new StompFrameHandler() { - @NotNull + private void subscribeToChatQueue() { + stompSession.subscribe("/user/queue/chat", new StompFrameHandler() { @Override - public Type getPayloadType(@NotNull StompHeaders headers) { - return String.class; + public Type getPayloadType(StompHeaders headers) { + return AnswerRes.class; } @Override - public void handleFrame(@NotNull StompHeaders headers, Object payload) { - queue.add((String)payload); + public void handleFrame(StompHeaders headers, Object payload) { + blockingQueue.offer((AnswerRes)payload); } }); - - // when: 메시지 전송 - session.send("/ws/chat/send/" + chatId, userMessage); - - // then: MockAiController가 보내는 응답 검증 - String response = queue.poll(5, TimeUnit.SECONDS); - - assertThat(response).isNotNull(); - assertThat(response).startsWith("안"); } @Test - @DisplayName("취소 요청 시 GENERATION_CANCELLED 메시지 수신") - void shouldReceiveCancelledMessageWhenCancelRequest() throws Exception { + @DisplayName("유저가 메시지를 보내면 AI 응답이 Queue로 수신된다") + void shouldReceiveAnswerWhenMessageSent() throws InterruptedException { // given - String wsUrl = String.format("ws://localhost:%d/ws/chat", port); - StompHeaders headers = new StompHeaders(); - headers.add("Authorization", "Bearer " + validToken); + Long chatId = testChat.getId(); + String userMessage = "Gemini에 대해 알려줘"; + AnswerReq req = new AnswerReq(chatId, userMessage); - WebSocketHttpHeaders handshakeHeaders = new WebSocketHttpHeaders(); + subscribeToChatQueue(); - StompSession session = stompClient.connectAsync(wsUrl, handshakeHeaders, headers, - new StompSessionHandlerAdapter() { - }) - .get(5, TimeUnit.SECONDS); + // when + stompSession.send("/ws/chat/send", req); - Long chatId = savedChat.getId(); - BlockingQueue queue = new LinkedBlockingQueue<>(); + // then + AnswerRes received = blockingQueue.poll(10, SECONDS); - session.subscribe("/topic/chat/" + chatId, new StompFrameHandler() { - @NotNull - @Override - public Type getPayloadType(@NotNull StompHeaders headers) { - return String.class; - } + assertThat(received).isNotNull(); + assertThat(received.chatId()).isEqualTo(chatId); + assertThat(received.success()).isTrue(); + assertThat(received.content()).contains("Gemini와 관련된 내용"); + } - @Override - public void handleFrame(@NotNull StompHeaders headers, Object payload) { - queue.add((String)payload); - } + @Test + @DisplayName("답변 생성 중 취소 요청 시, 작업이 중단되고 실패 메시지가 수신된다") + void shouldReceiveErrorMessageWhenCancelled() throws InterruptedException { + // given + Long chatId = testChat.getId(); + String userMessage = "취소될 질문"; + AnswerReq sendReq = new AnswerReq(chatId, userMessage); + AnswerCancelReq cancelReq = new AnswerCancelReq(chatId); + + given(answerClient.generateAnswer(any())).willAnswer(invocation -> { + Thread.sleep(500); + return new RagAnswerRes("지연된 답변", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), + true); }); - // when: 취소 요청 전송 - session.send("/ws/chat/cancel/" + chatId, ""); + subscribeToChatQueue(); + + // when + stompSession.send("/ws/chat/send", sendReq); + Thread.sleep(50); + stompSession.send("/ws/chat/cancel", cancelReq); // then - String response = queue.poll(5, TimeUnit.SECONDS); - assertThat(response).isEqualTo("GENERATION_CANCELLED"); + AnswerRes received = blockingQueue.poll(5, SECONDS); + + assertThat(received).isNotNull(); + assertThat(received.chatId()).isEqualTo(chatId); + assertThat(received.success()).isFalse(); + assertThat(received.content()).isEqualTo(userMessage); } + } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManagerTest.java b/src/test/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManagerTest.java deleted file mode 100644 index e77446c3..00000000 --- a/src/test/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManagerTest.java +++ /dev/null @@ -1,67 +0,0 @@ -package com.sofa.linkiving.domain.chat.manager; - -import static org.mockito.Mockito.*; - -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import reactor.core.Disposable; - -@ExtendWith(MockitoExtension.class) -public class SubscriptionManagerTest { - - @InjectMocks - private SubscriptionManager subscriptionManager; - - @Mock - private Disposable disposable; - - @Test - @DisplayName("구독 추가 요청 시 기존 구독이 있다면 취소 후 등록") - void shouldDisposeOldSubscriptionWhenAdd() { - // given - String key = "chat-1"; - Disposable oldDisposable = mock(Disposable.class); - - // 먼저 하나 등록 - subscriptionManager.add(key, oldDisposable); - - // when: 같은 키로 새로운 구독 등록 - subscriptionManager.add(key, disposable); - - // then: 이전 구독은 dispose 되어야 함 - verify(oldDisposable).dispose(); - } - - @Test - @DisplayName("구독 취소 요청 시 dispose 호출 및 제거") - void shouldDisposeWhenCancel() { - // given - String key = "chat-1"; - subscriptionManager.add(key, disposable); - - // when - subscriptionManager.cancel(key); - - // then - verify(disposable).dispose(); - } - - @Test - @DisplayName("완료된 구독 제거 요청 시 dispose 없이 맵에서만 제거") - void shouldNotDisposeWhenRemove() { - // given - String key = "chat-1"; - subscriptionManager.add(key, disposable); - - // when - subscriptionManager.remove(key); - - // then: remove는 dispose를 호출하지 않음 (이미 완료된 상태 가정) - verify(disposable, never()).dispose(); - } -} diff --git a/src/test/java/com/sofa/linkiving/domain/chat/repository/ChatRepositoryTest.java b/src/test/java/com/sofa/linkiving/domain/chat/repository/ChatRepositoryTest.java index 24f20a19..45f8914c 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/repository/ChatRepositoryTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/repository/ChatRepositoryTest.java @@ -145,4 +145,69 @@ void shouldReturnEmptyWhenChatIsNotMine() { // then assertThat(result).isEmpty(); // 조회되면 안 됨 (보안 검증) } + + @Test + @DisplayName("내 채팅방인 경우 true 반환") + void shouldReturnTrue_WhenChatExistsAndBelongsToMember() { + // given + Member me = memberRepository.save(Member.builder() + .email("me@test.com") + .password("password") + .build()); + + Chat myChat = chatRepository.save(Chat.builder() + .member(me) + .title("My Chat") + .build()); + + // when + boolean exists = chatRepository.existsByIdAndMember(myChat.getId(), me); + + // then + assertThat(exists).isTrue(); + } + + @Test + @DisplayName("다른 사람의 채팅방인 경우 false 반환") + void shouldReturnFalse_WhenChatBelongsToOtherMember() { + // given + Member me = memberRepository.save(Member.builder() + .email("me@test.com") + .password("password") + .build()); + + Member other = memberRepository.save(Member.builder() + .email("other@test.com") + .password("password") + .build()); + + Chat othersChat = chatRepository.save(Chat.builder() + .member(other) + .title("Other's Chat") + .build()); + + // when + boolean exists = chatRepository.existsByIdAndMember(othersChat.getId(), me); + + // then + assertThat(exists).isFalse(); + } + + @Test + @DisplayName(" 존재하지 않는 채팅방 ID인 경우 false 반환") + void shouldReturnFalse_WhenChatDoesNotExist() { + // given + Member me = memberRepository.save(Member.builder() + .email("me@test.com") + .password("password") + .build()); + + Long nonExistentChatId = 9999L; + + // when + boolean exists = chatRepository.existsByIdAndMember(nonExistentChatId, me); + + // then + assertThat(exists).isFalse(); + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/repository/MessageRepositoryTest.java b/src/test/java/com/sofa/linkiving/domain/chat/repository/MessageRepositoryTest.java index a6d7ce23..8017fb56 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/repository/MessageRepositoryTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/repository/MessageRepositoryTest.java @@ -194,4 +194,53 @@ void shouldReturnMessageWithFeedback() { assertThat(result.get(0).getFeedback()).isNotNull(); assertThat(result.get(0).getFeedback().getText()).isEqualTo("Good"); } + + @Test + @DisplayName(" 특정 ID보다 작은 메시지 중 최신 7개를 내림차순으로 조회") + void shouldReturnTop7MessagesBeforeGivenIdDesc() { + // given + for (int i = 1; i <= 15; i++) { + messageRepository.save(Message.builder() + .chat(chat) + .content("Message " + i) + .type(Type.USER) + .build()); + } + + List allMessages = messageRepository.findAll(); + Long targetId = allMessages.get(10).getId(); + + // when + List result = messageRepository.findTop7ByChatAndIdLessThanOrderByIdDesc(chat, targetId); + + // then + assertThat(result).hasSize(7); + assertThat(result).allMatch(msg -> msg.getId() < targetId); + assertThat(result.get(0).getId()).isGreaterThan(result.get(1).getId()); + assertThat(result.get(0).getContent()).isEqualTo(allMessages.get(9).getContent()); // ID 10 + assertThat(result.get(6).getContent()).isEqualTo(allMessages.get(3).getContent()); // ID 4 + } + + @Test + @DisplayName(" 조건에 맞는 메시지가 7개 미만이면 전체 반환") + void shouldReturnAllMessages_WhenLessThan7() { + // given + for (int i = 1; i <= 5; i++) { + messageRepository.save(Message.builder() + .chat(chat) + .content("Message " + i) + .type(Type.USER) + .build()); + } + + List allMessages = messageRepository.findAll(); + Long targetId = allMessages.get(4).getId() + 100L; + + // when + List result = messageRepository.findTop7ByChatAndIdLessThanOrderByIdDesc(chat, targetId); + + // then + assertThat(result).hasSize(5); + assertThat(result.get(0).getId()).isGreaterThan(result.get(4).getId()); // 정렬 확인 + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/ChatQueryServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/ChatQueryServiceTest.java index 607c5c50..a6f40868 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/service/ChatQueryServiceTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/ChatQueryServiceTest.java @@ -89,4 +89,20 @@ void shouldThrowExceptionWhenChatNotFound() { .isInstanceOf(BusinessException.class) .hasFieldOrPropertyWithValue("errorCode", ChatErrorCode.CHAT_NOT_FOUND); } + + @Test + @DisplayName("리포지토리를 호출하여 존재 여부를 반환한다") + void shouldDelegateToRepository() { + // given + Member member = mock(Member.class); + Long chatId = 1L; + given(chatRepository.existsByIdAndMember(chatId, member)).willReturn(false); + + // when + boolean result = chatQueryService.existsByIdAndMember(member, chatId); + + // then + assertThat(result).isFalse(); + verify(chatRepository).existsByIdAndMember(chatId, member); + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/ChatServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/ChatServiceTest.java index b89a6b19..b6358c41 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/service/ChatServiceTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/ChatServiceTest.java @@ -77,4 +77,20 @@ void shouldReturnChatWhenGetChat() { assertThat(result).isEqualTo(chat); verify(chatQueryService).findChat(chatId, member); } + + @Test + @DisplayName("QueryService를 호출하여 채팅방 존재 여부 반환") + void shouldDelegateToQueryService() { + // given + Member member = mock(Member.class); + Long chatId = 1L; + given(chatQueryService.existsByIdAndMember(member, chatId)).willReturn(true); + + // when + boolean result = chatService.existsChat(member, chatId); + + // then + assertThat(result).isTrue(); + verify(chatQueryService).existsByIdAndMember(member, chatId); + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/MessageCommandServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageCommandServiceTest.java index ff3568c9..8e26a6ca 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/service/MessageCommandServiceTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageCommandServiceTest.java @@ -3,16 +3,21 @@ import static org.assertj.core.api.Assertions.*; import static org.mockito.BDDMockito.*; +import java.util.List; + import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import com.sofa.linkiving.domain.chat.entity.Chat; import com.sofa.linkiving.domain.chat.entity.Message; +import com.sofa.linkiving.domain.chat.enums.Type; import com.sofa.linkiving.domain.chat.repository.MessageRepository; +import com.sofa.linkiving.domain.link.entity.Link; @ExtendWith(MockitoExtension.class) public class MessageCommandServiceTest { @@ -50,4 +55,51 @@ void shouldCallDeleteAllByChatWhenDeleteAllByChat() { // then verify(messageRepository).deleteAllByChat(chat); } + + @Test + @DisplayName("USER 타입의 메시지를 생성하고 저장한다") + void shouldSaveUserMessageCorrectly() { + // given + Chat chat = mock(Chat.class); + String content = "유저 질문"; + + // save 호출 시 입력된 객체를 그대로 반환하도록 설정 + given(messageRepository.save(any(Message.class))).willAnswer(invocation -> invocation.getArgument(0)); + + // when + Message savedMessage = messageCommandService.saveUserMessage(chat, content); + + // then + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + verify(messageRepository).save(captor.capture()); + + Message captured = captor.getValue(); + assertThat(captured.getChat()).isEqualTo(chat); + assertThat(captured.getContent()).isEqualTo(content); + assertThat(captured.getType()).isEqualTo(Type.USER); + } + + @Test + @DisplayName("AI 타입의 메시지와 링크 정보를 저장한다") + void shouldSaveAiMessageCorrectly() { + // given + Chat chat = mock(Chat.class); + String content = "AI 답변"; + List links = List.of(mock(Link.class)); + + given(messageRepository.save(any(Message.class))).willAnswer(invocation -> invocation.getArgument(0)); + + // when + Message savedMessage = messageCommandService.saveAiMessage(chat, content, links); + + // then + ArgumentCaptor captor = ArgumentCaptor.forClass(Message.class); + verify(messageRepository).save(captor.capture()); + + Message captured = captor.getValue(); + assertThat(captured.getChat()).isEqualTo(chat); + assertThat(captured.getContent()).isEqualTo(content); + assertThat(captured.getLinks()).isEqualTo(links); + assertThat(captured.getType()).isEqualTo(Type.AI); + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/MessageQueryServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageQueryServiceTest.java index ac0de736..c52caecd 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/service/MessageQueryServiceTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageQueryServiceTest.java @@ -76,4 +76,24 @@ void shouldReturnHasNextFalseWhenNoMoreData() { assertThat(result.hasNext()).isFalse(); assertThat(result.getContent()).hasSize(size); } + + @Test + @DisplayName("리포지토리를 호출하여 조건에 맞는 메시지 목록을 반환한다") + void shouldReturnMessages_WhenFound() { + // given + Long lastId = 100L; + Chat chat = mock(Chat.class); + Message message = mock(Message.class); + List expectedMessages = List.of(message); + + given(messageRepository.findTop7ByChatAndIdLessThanOrderByIdDesc(chat, lastId)) + .willReturn(expectedMessages); + + // when + List result = messageQueryService.findTop7ByChatIdAndIdLessThanOrderByIdDesc(lastId, chat); + + // then + assertThat(result).isEqualTo(expectedMessages); + verify(messageRepository).findTop7ByChatAndIdLessThanOrderByIdDesc(chat, lastId); + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/MessageServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageServiceTest.java index 8518b121..247f2e01 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/service/MessageServiceTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageServiceTest.java @@ -7,7 +7,6 @@ import java.util.List; import java.util.Map; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -18,12 +17,10 @@ import org.springframework.data.domain.Slice; import org.springframework.data.domain.SliceImpl; import org.springframework.messaging.simp.SimpMessagingTemplate; -import org.springframework.test.util.ReflectionTestUtils; import com.sofa.linkiving.domain.chat.dto.internal.MessagesDto; import com.sofa.linkiving.domain.chat.entity.Chat; import com.sofa.linkiving.domain.chat.entity.Message; -import com.sofa.linkiving.domain.chat.manager.SubscriptionManager; import com.sofa.linkiving.domain.link.entity.Link; import com.sofa.linkiving.domain.link.entity.Summary; import com.sofa.linkiving.domain.link.service.SummaryQueryService; @@ -49,9 +46,6 @@ public class MessageServiceTest { @Mock private SimpMessagingTemplate messagingTemplate; - @Mock - private SubscriptionManager subscriptionManager; - @BeforeEach void setUp() { lenient().when(chat.getId()).thenReturn(1L); @@ -175,37 +169,4 @@ void shouldRequestSummariesForDistinctLinks() { list.size() == 1 // 두 메시지에 링크가 총 2개지만, 같은 객체이므로 1개로 줄어야 함 )); } - - @Test - @DisplayName("답변 취소 요청 시 구독 취소 및 취소 메시지 전송") - void shouldCancelSubscriptionAndSendMessageWhenCancelAnswer() { - // given - String roomId = "1"; - - // when - messageService.cancelAnswer(chat); - - // then - verify(subscriptionManager).cancel(roomId); - verify(messagingTemplate).convertAndSend(eq("/topic/chat/" + roomId), eq("GENERATION_CANCELLED")); - } - - @Test - @DisplayName("이미 답변 생성 중일 경우 중복 요청 무시") - void shouldIgnoreRequestWhenAlreadyGenerating() { - // given - // messageBuffers 필드에 강제로 현재 채팅방 ID를 넣어 생성 중인 상태로 만듦 - @SuppressWarnings("unchecked") - Map buffers = (Map)ReflectionTestUtils.getField(messageService, - "messageBuffers"); - Assertions.assertNotNull(buffers); - buffers.put("1", new StringBuilder()); - - // when - messageService.generateAnswer(chat, "질문"); - - // then - // WebClient 호출 로직으로 넘어가지 않아야 하므로 SubscriptionManager 호출이 없어야 함 - verify(subscriptionManager, never()).add(anyString(), any()); - } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/RagChatServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/RagChatServiceTest.java new file mode 100644 index 00000000..93547a9c --- /dev/null +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/RagChatServiceTest.java @@ -0,0 +1,164 @@ +package com.sofa.linkiving.domain.chat.service; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.BDDMockito.*; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import com.sofa.linkiving.domain.chat.ai.AnswerClient; +import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq; +import com.sofa.linkiving.domain.chat.dto.response.AnswerRes; +import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes; +import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.entity.Message; +import com.sofa.linkiving.domain.link.dto.internal.LinkDto; +import com.sofa.linkiving.domain.link.entity.Link; +import com.sofa.linkiving.domain.link.service.LinkQueryService; +import com.sofa.linkiving.domain.member.entity.Member; + +@ExtendWith(MockitoExtension.class) +@DisplayName("RagChatService 단위 테스트") +public class RagChatServiceTest { + + @InjectMocks + private RagChatService ragChatService; + + @Mock + private AnswerClient answerClient; + + @Mock + private MessageCommandService messageCommandService; + + @Mock + private MessageQueryService messageQueryService; + + @Mock + private LinkQueryService linkQueryService; + + @Mock + private ChatQueryService chatQueryService; + + private Member member; + private Chat chat; + private Long chatId = 1L; + private String userMessage = "테스트 질문"; + + @BeforeEach + void setUp() { + member = mock(Member.class); + lenient().when(member.getId()).thenReturn(100L); + + chat = mock(Chat.class); + lenient().when(chat.getId()).thenReturn(chatId); + } + + @Test + @DisplayName(" 정상 흐름일 때 AI 응답을 처리하고 결과를 반환한다") + void shouldReturnAnswerRes_WhenProcessSuccessfully() throws ExecutionException, InterruptedException { + // given + // 1. Chat 조회 + given(chatQueryService.findChat(chatId, member)).willReturn(chat); + + // 2. 유저 메시지 저장 + Message questionMsg = mock(Message.class); + given(questionMsg.getId()).willReturn(50L); + given(messageCommandService.saveUserMessage(chat, userMessage)).willReturn(questionMsg); + + // 3. 과거 대화 내역 조회 + Message historyMsg = mock(Message.class); + given(historyMsg.getContent()).willReturn("이전 대화"); + given(messageQueryService.findTop7ByChatIdAndIdLessThanOrderByIdDesc(50L, chat)) + .willReturn(List.of(historyMsg)); + + // 4. AI Client 응답 설정 (유효한 링크 ID와 무효한 ID 혼합) + RagAnswerRes ragRes = new RagAnswerRes( + "AI 답변입니다.", + List.of("10", " invalid ", " 20 "), + List.of(new RagAnswerRes.ReasoningStep("생각 과정", List.of("10"))), + List.of("10", "20"), + false + ); + given(answerClient.generateAnswer(any(RagAnswerReq.class))).willReturn(ragRes); + + // 5. 링크 조회 + LinkDto linkDto1 = mock(LinkDto.class); + Link link1 = mock(Link.class); + given(linkDto1.link()).willReturn(link1); + + given(linkQueryService.findAllByIdInWithSummary(eq(List.of(10L, 20L)), eq(member))) + .willReturn(List.of(linkDto1)); + + // 6. AI 메시지 저장 + Message answerMsg = mock(Message.class); + given(answerMsg.getId()).willReturn(51L); + given(answerMsg.getContent()).willReturn("AI 답변입니다."); + given(messageCommandService.saveAiMessage(eq(chat), anyString(), anyList())) + .willReturn(answerMsg); + + // when + CompletableFuture future = ragChatService.generateAnswer(chatId, member, userMessage); + + // then + AnswerRes result = future.get(); + + assertThat(result).isNotNull(); + assertThat(result.chatId()).isEqualTo(chatId); + assertThat(result.content()).isEqualTo("AI 답변입니다."); + assertThat(result.links()).hasSize(1); + + // 순서대로 호출되었는지 검증 + verify(messageCommandService).saveUserMessage(chat, userMessage); + verify(answerClient).generateAnswer(any(RagAnswerReq.class)); + verify(linkQueryService).findAllByIdInWithSummary(eq(List.of(10L, 20L)), eq(member)); + verify(messageCommandService).saveAiMessage(chat, "AI 답변입니다.", List.of(link1)); + } + + @Test + @DisplayName("채팅방이 존재하지 않으면 예외 발생") + void shouldThrowException_WhenChatNotFound() { + // given + given(chatQueryService.findChat(chatId, member)) + .willThrow(new RuntimeException("Chat Not Found")); + + // when & then + assertThatThrownBy(() -> ragChatService.generateAnswer(chatId, member, userMessage)) + .isInstanceOf(RuntimeException.class) + .hasMessage("Chat Not Found"); + + verifyNoInteractions(answerClient); + } + + @Test + @DisplayName("AI 클라이언트 오류 발생 시 예외 전파") + void shouldThrowException_WhenAiClientFails() { + // given + given(chatQueryService.findChat(chatId, member)).willReturn(chat); + + Message questionMsg = mock(Message.class); + given(questionMsg.getId()).willReturn(50L); + given(messageCommandService.saveUserMessage(chat, userMessage)).willReturn(questionMsg); + + given(messageQueryService.findTop7ByChatIdAndIdLessThanOrderByIdDesc(anyLong(), any())) + .willReturn(Collections.emptyList()); + + given(answerClient.generateAnswer(any())) + .willThrow(new RuntimeException("AI Service Unavailable")); + + // when & then + assertThatThrownBy(() -> ragChatService.generateAnswer(chatId, member, userMessage)) + .isInstanceOf(RuntimeException.class) + .hasMessage("AI Service Unavailable"); + } +} diff --git a/src/test/java/com/sofa/linkiving/domain/link/repository/LinkRepositoryTest.java b/src/test/java/com/sofa/linkiving/domain/link/repository/LinkRepositoryTest.java index c1c7824f..6227ac5b 100644 --- a/src/test/java/com/sofa/linkiving/domain/link/repository/LinkRepositoryTest.java +++ b/src/test/java/com/sofa/linkiving/domain/link/repository/LinkRepositoryTest.java @@ -198,4 +198,76 @@ void shouldFindAllByMemberWithSummaryAndCursor() { assertThat(page2).hasSize(1); assertThat(page2.get(0).link().getTitle()).isEqualTo("링크 1"); } + + @Test + @DisplayName("조건에 맞는 링크와 선택된 요약을 조회한다") + void shouldReturnLinksWithSelectedSummary() { + // given + Member otherMember = Member.builder() + .email("other@test.com") + .password("password") + .build(); + entityManager.persist(otherMember); + + Link link1 = linkRepository.save(Link.builder() + .member(testMember) + .title("link1") + .url("http://url1.com") + .build()); + Summary summary1 = Summary.builder() + .link(link1) + .content("요약1") + .select(true) + .build(); + entityManager.persist(summary1); + + Link link2 = linkRepository.save(Link.builder() + .member(testMember) + .title("link2") + .url("http://url2.com") + .build()); + Summary summary2 = Summary.builder() + .link(link2) + .content("요약2") + .select(false) + .build(); + entityManager.persist(summary2); + + Link link3 = linkRepository.save(Link.builder() + .member(testMember) + .title("link3") + .url("http://url3.com") + .build()); + + Link link4 = linkRepository.save(Link.builder() + .member(otherMember) + .title("link4") + .url("http://url4.com") + .build()); + + List linkIds = List.of(link1.getId(), link2.getId(), link3.getId(), link4.getId()); + + // when + List result = linkRepository.findAllByMemberAndIdInWithSummaryAndIsDeleteFalse(linkIds, testMember); + + // then + assertThat(result).hasSize(3); + + List resultIds = result.stream().map(dto -> dto.link().getId()).toList(); + assertThat(resultIds).containsExactlyInAnyOrder(link1.getId(), link2.getId(), link3.getId()); + } + + @Test + @DisplayName("요청한 ID 목록에 해당하는 링크가 없으면 빈 리스트를 반환한다") + void shouldReturnEmptyList_WhenNoMatch() { + // given + List nonExistentIds = List.of(999L, 1000L); + + // when + List result = linkRepository.findAllByMemberAndIdInWithSummaryAndIsDeleteFalse(nonExistentIds, + testMember); + + // then + assertThat(result).isEmpty(); + } } diff --git a/src/test/java/com/sofa/linkiving/domain/link/service/LinkQueryServiceTest.java b/src/test/java/com/sofa/linkiving/domain/link/service/LinkQueryServiceTest.java index bcdae2d9..2432f057 100644 --- a/src/test/java/com/sofa/linkiving/domain/link/service/LinkQueryServiceTest.java +++ b/src/test/java/com/sofa/linkiving/domain/link/service/LinkQueryServiceTest.java @@ -187,4 +187,23 @@ void shouldThrowExceptionWhenLinkNotFoundInFindByIdWithSummary() { .isInstanceOf(BusinessException.class) .hasFieldOrPropertyWithValue("errorCode", LinkErrorCode.LINK_NOT_FOUND); } + + @Test + @DisplayName("리포지토리를 호출하여 링크 목록을 반환한다") + void shouldReturnLinks_WhenCalled() { + // given + Member member = mock(Member.class); + List linkIds = List.of(1L, 2L, 3L); + List expectedDtos = List.of(mock(LinkDto.class)); + + given(linkRepository.findAllByMemberAndIdInWithSummaryAndIsDeleteFalse(linkIds, member)) + .willReturn(expectedDtos); + + // when + List result = linkQueryService.findAllByIdInWithSummary(linkIds, member); + + // then + assertThat(result).isEqualTo(expectedDtos); + verify(linkRepository).findAllByMemberAndIdInWithSummaryAndIsDeleteFalse(linkIds, member); + } }