Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import java.util.List;

import org.springframework.context.annotation.Primary;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq;
Expand All @@ -12,7 +12,7 @@

@Slf4j
@Component
@Primary
@Profile("test")
public class MockAnswerClient implements AnswerClient {

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.sofa.linkiving.domain.chat.ai;

import java.util.List;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq;
import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

@Slf4j
@Component
@Profile("!test")
@RequiredArgsConstructor
public class RagAnswerClient implements AnswerClient {

private final RagAnswerFeign ragAnswerFeign;

@Override
public RagAnswerRes generateAnswer(RagAnswerReq request) {
try {
List<RagAnswerRes> ragAnswerRes = ragAnswerFeign.generateAnswer(request);
log.info("RagAnswerClient generateAnswer ragAnswerRes={}", ragAnswerRes);
return ragAnswerRes.get(0);
} catch (Exception e) {
log.error("RagAnswerClient generateAnswer error", e);
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.sofa.linkiving.domain.chat.ai;

import java.util.List;

import org.springframework.cloud.openfeign.FeignClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;

import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq;
import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes;
import com.sofa.linkiving.infra.feign.GlobalFeignConfig;

@FeignClient(name = "ai-answer-client", url = "${ai.server.url}", configuration = GlobalFeignConfig.class)
public interface RagAnswerFeign {
@PostMapping("/webhook/chat-answer")
List<RagAnswerRes> generateAnswer(@RequestBody RagAnswerReq request);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,31 @@

import java.util.List;

import com.sofa.linkiving.domain.chat.entity.Message;
import com.sofa.linkiving.domain.chat.enums.Mode;
import com.sofa.linkiving.domain.chat.enums.Type;

public record RagAnswerReq(
Long userId,
String question,
List<String> history,
List<RagMessageReq> history,
Mode mode
) {
public static RagAnswerReq of(Long userId, String question, List<Message> messages, Mode mode) {
List<RagMessageReq> history = messages.stream()
.map(RagMessageReq::from)
.toList();

return new RagAnswerReq(userId, question, history, mode);
}

public record RagMessageReq(
String role,
String content
) {
public static RagMessageReq from(Message message) {
String role = (message.getType() == Type.AI) ? "system" : "user";
return new RagMessageReq(role, message.getContent());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,24 @@ public void generateAnswer(Long chatId, Member member, String message) {
taskManager.remove(chatId);

if (task.isCancelled() || ex != null) {

if (ex != null) {
log.error("AI 답변 생성 중 오류 발생 - chatId: {}, error: {}", chatId, ex.getMessage(), ex);
} else {
log.info("AI 답변 생성 작업 취소됨 - chatId: {}", chatId);
}

sendNotification(chatId, member, AnswerRes.error(chatId, message));
return;
}

if (result != null) {
sendNotification(chatId, member, result);
return;
}

log.error("AI 답변 생성 결과가 null 입니다 - chatId: {}", chatId);
sendNotification(chatId, member, AnswerRes.error(chatId, message));
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.sofa.linkiving.domain.chat.service;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -43,15 +42,11 @@ public CompletableFuture<AnswerRes> generateAnswer(Long chatId, Member member, S
Chat chat = chatQueryService.findChat(chatId, member);

Message question = messageCommandService.saveUserMessage(chat, userMessage);
List<Message> pastMessages = messageQueryService.findTop7ByChatIdAndIdLessThanOrderByIdDesc(
List<Message> history = messageQueryService.findTop7ByChatIdAndIdLessThanOrderByIdDesc(
question.getId(), chat);

List<String> history = new ArrayList<>(pastMessages.stream()
.map(Message::getContent)
.toList());
Collections.reverse(history);

RagAnswerReq request = new RagAnswerReq(
RagAnswerReq request = RagAnswerReq.of(
member.getId(),
userMessage,
history,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.sofa.linkiving.domain.chat.ai;

import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.BDDMockito.*;

import java.util.List;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import com.sofa.linkiving.domain.chat.dto.request.RagAnswerReq;
import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes;

@ExtendWith(MockitoExtension.class)
@DisplayName("RagAnswerClient 단위 테스트")
class RagAnswerClientTest {

@InjectMocks
private RagAnswerClient ragAnswerClient;

@Mock
private RagAnswerFeign ragAnswerFeign;

@Test
@DisplayName("generateAnswer: Feign 응답이 정상일 경우 리스트의 첫 번째 요소를 반환한다")
void shouldReturnFirstElement_WhenGenerateAnswerSuccess() {
// given
RagAnswerReq req = mock(RagAnswerReq.class);
RagAnswerRes expectedRes = mock(RagAnswerRes.class);
given(ragAnswerFeign.generateAnswer(any(RagAnswerReq.class)))
.willReturn(List.of(expectedRes));

// when
RagAnswerRes actualRes = ragAnswerClient.generateAnswer(req);

// then
assertThat(actualRes).isEqualTo(expectedRes);
}

@Test
@DisplayName("generateAnswer: Feign 요청 중 예외가 발생하면 예외를 잡고 null을 반환한다")
void shouldCatchExceptionAndReturnNull_WhenGenerateAnswerThrowsException() {
// given
RagAnswerReq req = mock(RagAnswerReq.class);
given(ragAnswerFeign.generateAnswer(any(RagAnswerReq.class)))
.willThrow(new RuntimeException("AI Server Error"));

// when
RagAnswerRes actualRes = ragAnswerClient.generateAnswer(req);

// then
assertThat(actualRes).isNull();
}
}