Skip to content

Commit 9031dc3

Browse files
authored
fix(rag): RAG-retrieved use USER role, AGENTIC RAG mode add missing retrieveConfig (#949)
## AgentScope-Java Version 1.0.11 ## Description * RAG-retrieved use USER role, Fixes: #941 * AGENTIC RAG mode add missing retrieveConfig, Fixes: #932 ## Checklist Please check the following items before code is ready to be reviewed. - [x] Code has been formatted with `mvn spotless:apply` - [x] All tests are passing (`mvn test`) - [x] Javadoc comments are complete and follow project conventions - [x] Related documentation has been updated (e.g. links, examples, etc.) - [x] Code is ready for review
1 parent be80cad commit 9031dc3

File tree

7 files changed

+129
-17
lines changed

7 files changed

+129
-17
lines changed

agentscope-core/src/main/java/io/agentscope/core/ReActAgent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1439,7 +1439,7 @@ private void configureRAG(Toolkit agentToolkit) {
14391439
case AGENTIC -> {
14401440
// Register knowledge retrieval tools
14411441
KnowledgeRetrievalTools tools =
1442-
new KnowledgeRetrievalTools(aggregatedKnowledge);
1442+
new KnowledgeRetrievalTools(aggregatedKnowledge, retrieveConfig);
14431443
agentToolkit.registerTool(tools);
14441444
}
14451445
case NONE -> {

agentscope-core/src/main/java/io/agentscope/core/rag/GenericRAGHook.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
* <ol>
4242
* <li>Extracts the query from user messages</li>
4343
* <li>Retrieves relevant documents from the knowledge base</li>
44-
* <li>Injects the retrieved knowledge as a system message</li>
44+
* <li>Injects the retrieved knowledge as a user message</li>
4545
* <li>Modifies the input messages to include the knowledge context</li>
4646
* </ol>
4747
*
@@ -82,7 +82,7 @@ public class GenericRAGHook implements Hook {
8282
* @throws IllegalArgumentException if knowledgeBase is null
8383
*/
8484
public GenericRAGHook(Knowledge knowledge) {
85-
this(knowledge, RetrieveConfig.builder().limit(5).scoreThreshold(0.5).build());
85+
this(knowledge, RetrieveConfig.builder().build());
8686
}
8787

8888
/**
@@ -198,8 +198,8 @@ private Msg createEnhancedMessages(List<Document> retrievedDocs) {
198198
String knowledgeContent = buildKnowledgeContent(retrievedDocs);
199199

200200
return Msg.builder()
201-
.name("system")
202-
.role(MsgRole.SYSTEM)
201+
.name("user")
202+
.role(MsgRole.USER)
203203
.content(TextBlock.builder().text(knowledgeContent).build())
204204
.build();
205205
}

agentscope-core/src/main/java/io/agentscope/core/rag/KnowledgeRetrievalTools.java

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
* <p>Example usage:
3838
* <pre>{@code
3939
* KnowledgeBase knowledgeBase = new SimpleKnowledge(embeddingModel, vectorStore);
40-
* KnowledgeRetrievalTools tools = new KnowledgeRetrievalTools(knowledgeBase);
40+
* KnowledgeRetrievalTools tools = new KnowledgeRetrievalTools(knowledgeBase, RetrieveConfig.builder().build());
4141
*
4242
* Toolkit toolkit = new Toolkit();
4343
* toolkit.registerObject(tools);
@@ -53,17 +53,40 @@ public class KnowledgeRetrievalTools {
5353

5454
private final Knowledge knowledge;
5555

56+
private final RetrieveConfig defaultConfig;
57+
5658
/**
57-
* Creates a new KnowledgeRetrievalTools instance.
59+
* Creates a new KnowledgeRetrievalTools instance with default configuration.
60+
*
61+
* <p>Default configuration:
62+
* <ul>
63+
* <li>Limit: 5 documents</li>
64+
* <li>Score threshold: 0.5</li>
65+
* </ul>
5866
*
5967
* @param knowledge the knowledge base to retrieve from
6068
* @throws IllegalArgumentException if knowledgeBase is null
6169
*/
6270
public KnowledgeRetrievalTools(Knowledge knowledge) {
71+
this(knowledge, RetrieveConfig.builder().build());
72+
}
73+
74+
/**
75+
* Creates a new KnowledgeRetrievalTools instance.
76+
*
77+
* @param knowledge the knowledge base to retrieve from
78+
* @param defaultConfig the default retrieval configuration
79+
* @throws IllegalArgumentException if knowledgeBase is null
80+
*/
81+
public KnowledgeRetrievalTools(Knowledge knowledge, RetrieveConfig defaultConfig) {
6382
if (knowledge == null) {
6483
throw new IllegalArgumentException("Knowledge base cannot be null");
6584
}
85+
if (defaultConfig == null) {
86+
throw new IllegalArgumentException("Retrieve config cannot be null");
87+
}
6688
this.knowledge = knowledge;
89+
this.defaultConfig = defaultConfig;
6790
}
6891

6992
/**
@@ -121,9 +144,9 @@ public String retrieveKnowledge(
121144

122145
// Build retrieval config with conversation history
123146
RetrieveConfig config =
124-
RetrieveConfig.builder()
147+
this.defaultConfig
148+
.mutate()
125149
.limit(limit)
126-
.scoreThreshold(0.5)
127150
.conversationHistory(conversationHistory)
128151
.build();
129152

@@ -172,4 +195,13 @@ private String formatDocumentsForTool(List<Document> documents) {
172195
public Knowledge getKnowledgeBase() {
173196
return knowledge;
174197
}
198+
199+
/**
200+
* Gets the default retrieval configuration.
201+
*
202+
* @return the default config
203+
*/
204+
public RetrieveConfig getDefaultConfig() {
205+
return defaultConfig;
206+
}
175207
}

agentscope-core/src/main/java/io/agentscope/core/rag/model/RetrieveConfig.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,19 @@ public List<Msg> getConversationHistory() {
7878
return conversationHistory;
7979
}
8080

81+
/**
82+
* Mutate the current instance to a new builder.
83+
*
84+
* @return a new builder with the same values of this instance
85+
*/
86+
public Builder mutate() {
87+
return new Builder()
88+
.limit(this.limit)
89+
.scoreThreshold(this.scoreThreshold)
90+
.vectorName(this.vectorName)
91+
.conversationHistory(this.conversationHistory);
92+
}
93+
8194
/**
8295
* Creates a new builder instance.
8396
*

agentscope-core/src/test/java/io/agentscope/core/rag/model/RetrieveConfigTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import static org.junit.jupiter.api.Assertions.assertEquals;
1919
import static org.junit.jupiter.api.Assertions.assertThrows;
2020

21+
import io.agentscope.core.message.Msg;
22+
import java.util.List;
2123
import org.junit.jupiter.api.DisplayName;
2224
import org.junit.jupiter.api.Tag;
2325
import org.junit.jupiter.api.Test;
@@ -101,4 +103,17 @@ void testBuilderChaining() {
101103
assertEquals(0.8, config.getScoreThreshold());
102104
assertEquals("test-vector", config.getVectorName());
103105
}
106+
107+
@Test
108+
@DisplayName("Should mutate RetrieveConfig with same values")
109+
void testMutate() {
110+
RetrieveConfig originConfig = RetrieveConfig.builder().build();
111+
List<Msg> conversationHistory = List.of(Msg.builder().textContent("test content").build());
112+
RetrieveConfig mutateConfig =
113+
originConfig.mutate().conversationHistory(conversationHistory).build();
114+
assertEquals(conversationHistory, mutateConfig.getConversationHistory());
115+
assertEquals(originConfig.getLimit(), mutateConfig.getLimit());
116+
assertEquals(originConfig.getScoreThreshold(), mutateConfig.getScoreThreshold());
117+
assertEquals(originConfig.getVectorName(), mutateConfig.getVectorName());
118+
}
104119
}

agentscope-extensions/agentscope-extensions-rag-simple/src/test/java/io/agentscope/core/rag/KnowledgeRetrievalToolsTest.java

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222

2323
import io.agentscope.core.embedding.EmbeddingModel;
2424
import io.agentscope.core.message.ContentBlock;
25+
import io.agentscope.core.message.Msg;
2526
import io.agentscope.core.message.TextBlock;
2627
import io.agentscope.core.message.ToolResultBlock;
2728
import io.agentscope.core.rag.knowledge.SimpleKnowledge;
2829
import io.agentscope.core.rag.model.Document;
2930
import io.agentscope.core.rag.model.DocumentMetadata;
31+
import io.agentscope.core.rag.model.RetrieveConfig;
3032
import io.agentscope.core.rag.store.InMemoryStore;
3133
import io.agentscope.core.tool.AgentTool;
3234
import io.agentscope.core.tool.ToolCallParam;
@@ -51,6 +53,7 @@ class KnowledgeRetrievalToolsTest {
5153
private static final int DIMENSIONS = 3;
5254

5355
private Knowledge knowledge;
56+
private RetrieveConfig retrieveConfig;
5457
private KnowledgeRetrievalTools tools;
5558
private Toolkit toolkit;
5659

@@ -63,16 +66,29 @@ void setUp() {
6366
.embeddingModel(embeddingModel)
6467
.embeddingStore(vectorStore)
6568
.build();
66-
tools = new KnowledgeRetrievalTools(knowledge);
69+
retrieveConfig = RetrieveConfig.builder().build();
70+
tools = new KnowledgeRetrievalTools(knowledge, retrieveConfig);
6771
toolkit = new Toolkit();
6872
}
6973

7074
@Test
7175
@DisplayName("Should create KnowledgeRetrievalTools with valid knowledge base")
7276
void testCreate() {
77+
KnowledgeRetrievalTools newTools = new KnowledgeRetrievalTools(knowledge, retrieveConfig);
78+
assertNotNull(newTools);
79+
assertEquals(knowledge, newTools.getKnowledgeBase());
80+
assertEquals(retrieveConfig, newTools.getDefaultConfig());
81+
}
82+
83+
@Test
84+
@DisplayName("Should create KnowledgeRetrievalTools with default config")
85+
void testCreateWithDefaultConfig() {
7386
KnowledgeRetrievalTools newTools = new KnowledgeRetrievalTools(knowledge);
7487
assertNotNull(newTools);
7588
assertEquals(knowledge, newTools.getKnowledgeBase());
89+
assertNotNull(newTools.getDefaultConfig());
90+
assertEquals(5, newTools.getDefaultConfig().getLimit());
91+
assertEquals(0.5, newTools.getDefaultConfig().getScoreThreshold());
7692
}
7793

7894
@Test
@@ -81,6 +97,35 @@ void testCreateNullKnowledgeBase() {
8197
assertThrows(IllegalArgumentException.class, () -> new KnowledgeRetrievalTools(null));
8298
}
8399

100+
@Test
101+
@DisplayName("Should throw exception for null retrieve config")
102+
void testCreateNullConfig() {
103+
assertThrows(
104+
IllegalArgumentException.class, () -> new KnowledgeRetrievalTools(knowledge, null));
105+
}
106+
107+
@Test
108+
@DisplayName("Should create KnowledgeRetrievalTools with custom retrieve config")
109+
void testCreateCustomRetrieveConfig() {
110+
List<Msg> conversationHistory = List.of(Msg.builder().textContent("Hello").build());
111+
RetrieveConfig customRetrieveConfig =
112+
RetrieveConfig.builder()
113+
.limit(3)
114+
.scoreThreshold(0.9)
115+
.vectorName("test_vector")
116+
.conversationHistory(conversationHistory)
117+
.build();
118+
KnowledgeRetrievalTools newTools =
119+
new KnowledgeRetrievalTools(knowledge, customRetrieveConfig);
120+
assertNotNull(newTools);
121+
assertEquals(knowledge, newTools.getKnowledgeBase());
122+
assertEquals(customRetrieveConfig, newTools.getDefaultConfig());
123+
assertEquals(3, newTools.getDefaultConfig().getLimit());
124+
assertEquals(0.9, newTools.getDefaultConfig().getScoreThreshold());
125+
assertEquals("test_vector", newTools.getDefaultConfig().getVectorName());
126+
assertEquals(conversationHistory, newTools.getDefaultConfig().getConversationHistory());
127+
}
128+
84129
@Test
85130
@DisplayName("Should register tool with Toolkit")
86131
void testToolRegistration() {

agentscope-extensions/agentscope-extensions-rag-simple/src/test/java/io/agentscope/core/rag/hook/GenericRAGHookTest.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package io.agentscope.core.rag.hook;
1717

1818
import static org.junit.jupiter.api.Assertions.assertEquals;
19+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
1920
import static org.junit.jupiter.api.Assertions.assertNotNull;
2021
import static org.junit.jupiter.api.Assertions.assertThrows;
2122
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -101,6 +102,8 @@ void testCreateWithDefaults() {
101102
assertNotNull(newHook);
102103
assertEquals(knowledge, newHook.getKnowledgeBase());
103104
assertNotNull(newHook.getDefaultConfig());
105+
assertEquals(5, newHook.getDefaultConfig().getLimit());
106+
assertEquals(0.5, newHook.getDefaultConfig().getScoreThreshold());
104107
}
105108

106109
@Test
@@ -149,16 +152,20 @@ void testHandlePreCallEvent() {
149152
StepVerifier.create(hook.onEvent(event))
150153
.assertNext(
151154
result -> {
152-
assertTrue(result instanceof PreCallEvent);
153-
PreCallEvent preCallEvent = (PreCallEvent) result;
154-
List<Msg> enhancedMessages = preCallEvent.getInputMessages();
155+
assertInstanceOf(PreCallEvent.class, result);
156+
List<Msg> enhancedMessages = result.getInputMessages();
155157

156158
// Should have knowledge message + original message
157159
assertTrue(enhancedMessages.size() >= 2);
158-
// First message should be system message with knowledge
159-
Msg firstMsg = enhancedMessages.get(1);
160-
assertEquals(MsgRole.SYSTEM, firstMsg.getRole());
161-
assertTrue(firstMsg.getTextContent().contains("knowledge base"));
160+
// First message should be user message with question
161+
assertEquals(MsgRole.USER, enhancedMessages.get(0).getRole());
162+
// Second message should be user message with knowledge retrieval
163+
assertEquals(MsgRole.USER, enhancedMessages.get(1).getRole());
164+
assertTrue(
165+
enhancedMessages
166+
.get(1)
167+
.getTextContent()
168+
.contains("knowledge base"));
162169
})
163170
.verifyComplete();
164171
}

0 commit comments

Comments
 (0)