@@ -18,7 +18,11 @@ import {
18
18
} from "@langchain/langgraph" ;
19
19
import { MemorySaver } from "@langchain/langgraph-checkpoint" ;
20
20
import { FakeStreamingChatModel } from "@langchain/core/utils/testing" ;
21
- import { AIMessage } from "@langchain/core/messages" ;
21
+ import {
22
+ AIMessage ,
23
+ BaseMessage ,
24
+ RemoveMessage ,
25
+ } from "@langchain/core/messages" ;
22
26
import {
23
27
createEmbedServer ,
24
28
type ThreadSaver ,
@@ -85,8 +89,47 @@ const interruptAgent = new StateGraph(MessagesAnnotation)
85
89
. addEdge ( "afterInterrupt" , END )
86
90
. compile ( ) ;
87
91
92
+ const removeMessageAgent = new StateGraph ( MessagesAnnotation )
93
+ . addSequence ( {
94
+ step1 : ( ) => ( { messages : [ new AIMessage ( "Step 1: To Remove" ) ] } ) ,
95
+ step2 : async ( state , config ) => {
96
+ // Send message before persisting to state
97
+ // TODO: replace with `pushMessage` when part of 1.x
98
+ const messages : BaseMessage [ ] = [
99
+ ...state . messages
100
+ . filter ( ( m ) => m . getType ( ) === "ai" )
101
+ . map ( ( m ) => new RemoveMessage ( { id : m . id ! } ) ) ,
102
+ new AIMessage ( { id : randomUUID ( ) , content : "Step 2: To Keep" } ) ,
103
+ ] ;
104
+
105
+ const messagesHandler = (
106
+ config . callbacks as { handlers : object [ ] }
107
+ ) ?. handlers ?. find (
108
+ (
109
+ cb
110
+ ) : cb is {
111
+ _emit : (
112
+ chunk : [ namespace : string [ ] , metadata : Record < string , unknown > ] ,
113
+ message : BaseMessage ,
114
+ runId : string | undefined ,
115
+ dedupe : boolean
116
+ ) => void ;
117
+ } => "name" in cb && cb . name === "StreamMessagesHandler"
118
+ ) ;
119
+
120
+ for ( const message of messages ) {
121
+ messagesHandler ?. _emit ( [ [ ] , { } ] , message , undefined , false ) ;
122
+ }
123
+
124
+ return { messages } ;
125
+ } ,
126
+ step3 : ( ) => ( { messages : [ new AIMessage ( "Step 3: To Keep" ) ] } ) ,
127
+ } )
128
+ . addEdge ( START , "step1" )
129
+ . compile ( ) ;
130
+
88
131
const app = createEmbedServer ( {
89
- graph : { agent, parentAgent, interruptAgent } ,
132
+ graph : { agent, parentAgent, interruptAgent, removeMessageAgent } ,
90
133
checkpointer,
91
134
threads,
92
135
} ) ;
@@ -1194,4 +1237,75 @@ describe("useStream", () => {
1194
1237
} ) ;
1195
1238
}
1196
1239
) ;
1240
+
1241
+ it ( "handle message removal" , async ( ) => {
1242
+ const user = userEvent . setup ( ) ;
1243
+ const messagesValues = new Set < string > ( ) ;
1244
+
1245
+ function TestComponent ( ) {
1246
+ const { submit, messages, isLoading } = useStream ( {
1247
+ assistantId : "removeMessageAgent" ,
1248
+ apiKey : "test-api-key" ,
1249
+ } ) ;
1250
+
1251
+ const rawMessages = messages . map ( ( msg , i ) => ( {
1252
+ id : msg . id ?? i ,
1253
+ content : `${ msg . type } : ${
1254
+ typeof msg . content === "string"
1255
+ ? msg . content
1256
+ : JSON . stringify ( msg . content )
1257
+ } `,
1258
+ } ) ) ;
1259
+
1260
+ messagesValues . add ( rawMessages . map ( ( msg ) => msg . content ) . join ( "\n" ) ) ;
1261
+
1262
+ return (
1263
+ < div >
1264
+ < div data-testid = "loading" >
1265
+ { isLoading ? "Loading..." : "Not loading" }
1266
+ </ div >
1267
+ < div data-testid = "messages" >
1268
+ { rawMessages . map ( ( msg , i ) => (
1269
+ < div key = { msg . id } data-testid = { `message-${ i } ` } >
1270
+ < span > { msg . content } </ span >
1271
+ </ div >
1272
+ ) ) }
1273
+ </ div >
1274
+ < button
1275
+ data-testid = "submit"
1276
+ onClick = { ( ) =>
1277
+ submit ( { messages : [ { content : "Hello" , type : "human" } ] } )
1278
+ }
1279
+ >
1280
+ Send
1281
+ </ button >
1282
+ </ div >
1283
+ ) ;
1284
+ }
1285
+
1286
+ render ( < TestComponent /> ) ;
1287
+
1288
+ await user . click ( screen . getByTestId ( "submit" ) ) ;
1289
+
1290
+ await waitFor ( ( ) => {
1291
+ expect ( screen . getByTestId ( "loading" ) ) . toHaveTextContent ( "Not loading" ) ;
1292
+ expect ( screen . getByTestId ( "message-0" ) ) . toHaveTextContent ( "human: Hello" ) ;
1293
+ expect ( screen . getByTestId ( "message-1" ) ) . toHaveTextContent (
1294
+ "ai: Step 2: To Keep"
1295
+ ) ;
1296
+ expect ( screen . getByTestId ( "message-2" ) ) . toHaveTextContent (
1297
+ "ai: Step 3: To Keep"
1298
+ ) ;
1299
+ } ) ;
1300
+
1301
+ expect ( [ ...messagesValues . values ( ) ] ) . toMatchObject (
1302
+ [
1303
+ [ ] ,
1304
+ [ "human: Hello" ] ,
1305
+ [ "human: Hello" , "ai: Step 1: To Remove" ] ,
1306
+ [ "human: Hello" , "ai: Step 2: To Keep" ] ,
1307
+ [ "human: Hello" , "ai: Step 2: To Keep" , "ai: Step 3: To Keep" ] ,
1308
+ ] . map ( ( msg ) => msg . join ( "\n" ) )
1309
+ ) ;
1310
+ } ) ;
1197
1311
} ) ;
0 commit comments