Skip to content

Commit 599a8c5

Browse files
authored
fix(sdk): handle RemoveMessage when received in messages-tuple stream (#1656)
1 parent b56d6e3 commit 599a8c5

File tree

4 files changed

+138
-4
lines changed

4 files changed

+138
-4
lines changed

.changeset/itchy-lights-float.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@langchain/langgraph-sdk": patch
3+
---
4+
5+
Add support for streaming of RemoveMessage in useStream

libs/sdk-validation/src/tests/stream.test.tsx

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ import {
1818
} from "@langchain/langgraph";
1919
import { MemorySaver } from "@langchain/langgraph-checkpoint";
2020
import { FakeStreamingChatModel } from "@langchain/core/utils/testing";
21-
import { AIMessage } from "@langchain/core/messages";
21+
import {
22+
AIMessage,
23+
BaseMessage,
24+
RemoveMessage,
25+
} from "@langchain/core/messages";
2226
import {
2327
createEmbedServer,
2428
type ThreadSaver,
@@ -85,8 +89,47 @@ const interruptAgent = new StateGraph(MessagesAnnotation)
8589
.addEdge("afterInterrupt", END)
8690
.compile();
8791

92+
const removeMessageAgent = new StateGraph(MessagesAnnotation)
93+
.addSequence({
94+
step1: () => ({ messages: [new AIMessage("Step 1: To Remove")] }),
95+
step2: async (state, config) => {
96+
// Send message before persisting to state
97+
// TODO: replace with `pushMessage` when part of 1.x
98+
const messages: BaseMessage[] = [
99+
...state.messages
100+
.filter((m) => m.getType() === "ai")
101+
.map((m) => new RemoveMessage({ id: m.id! })),
102+
new AIMessage({ id: randomUUID(), content: "Step 2: To Keep" }),
103+
];
104+
105+
const messagesHandler = (
106+
config.callbacks as { handlers: object[] }
107+
)?.handlers?.find(
108+
(
109+
cb
110+
): cb is {
111+
_emit: (
112+
chunk: [namespace: string[], metadata: Record<string, unknown>],
113+
message: BaseMessage,
114+
runId: string | undefined,
115+
dedupe: boolean
116+
) => void;
117+
} => "name" in cb && cb.name === "StreamMessagesHandler"
118+
);
119+
120+
for (const message of messages) {
121+
messagesHandler?._emit([[], {}], message, undefined, false);
122+
}
123+
124+
return { messages };
125+
},
126+
step3: () => ({ messages: [new AIMessage("Step 3: To Keep")] }),
127+
})
128+
.addEdge(START, "step1")
129+
.compile();
130+
88131
const app = createEmbedServer({
89-
graph: { agent, parentAgent, interruptAgent },
132+
graph: { agent, parentAgent, interruptAgent, removeMessageAgent },
90133
checkpointer,
91134
threads,
92135
});
@@ -1194,4 +1237,75 @@ describe("useStream", () => {
11941237
});
11951238
}
11961239
);
1240+
1241+
it("handle message removal", async () => {
1242+
const user = userEvent.setup();
1243+
const messagesValues = new Set<string>();
1244+
1245+
function TestComponent() {
1246+
const { submit, messages, isLoading } = useStream({
1247+
assistantId: "removeMessageAgent",
1248+
apiKey: "test-api-key",
1249+
});
1250+
1251+
const rawMessages = messages.map((msg, i) => ({
1252+
id: msg.id ?? i,
1253+
content: `${msg.type}: ${
1254+
typeof msg.content === "string"
1255+
? msg.content
1256+
: JSON.stringify(msg.content)
1257+
}`,
1258+
}));
1259+
1260+
messagesValues.add(rawMessages.map((msg) => msg.content).join("\n"));
1261+
1262+
return (
1263+
<div>
1264+
<div data-testid="loading">
1265+
{isLoading ? "Loading..." : "Not loading"}
1266+
</div>
1267+
<div data-testid="messages">
1268+
{rawMessages.map((msg, i) => (
1269+
<div key={msg.id} data-testid={`message-${i}`}>
1270+
<span>{msg.content}</span>
1271+
</div>
1272+
))}
1273+
</div>
1274+
<button
1275+
data-testid="submit"
1276+
onClick={() =>
1277+
submit({ messages: [{ content: "Hello", type: "human" }] })
1278+
}
1279+
>
1280+
Send
1281+
</button>
1282+
</div>
1283+
);
1284+
}
1285+
1286+
render(<TestComponent />);
1287+
1288+
await user.click(screen.getByTestId("submit"));
1289+
1290+
await waitFor(() => {
1291+
expect(screen.getByTestId("loading")).toHaveTextContent("Not loading");
1292+
expect(screen.getByTestId("message-0")).toHaveTextContent("human: Hello");
1293+
expect(screen.getByTestId("message-1")).toHaveTextContent(
1294+
"ai: Step 2: To Keep"
1295+
);
1296+
expect(screen.getByTestId("message-2")).toHaveTextContent(
1297+
"ai: Step 3: To Keep"
1298+
);
1299+
});
1300+
1301+
expect([...messagesValues.values()]).toMatchObject(
1302+
[
1303+
[],
1304+
["human: Hello"],
1305+
["human: Hello", "ai: Step 1: To Remove"],
1306+
["human: Hello", "ai: Step 2: To Keep"],
1307+
["human: Hello", "ai: Step 2: To Keep", "ai: Step 3: To Keep"],
1308+
].map((msg) => msg.join("\n"))
1309+
);
1310+
});
11971311
});

libs/sdk/src/react/manager.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,11 @@ export class StreamManager<
298298
this.messages.get(messageId, messages.length) ?? {};
299299

300300
if (!chunk || index == null) return values;
301-
messages[index] = toMessageDict(chunk);
301+
if (chunk.getType() === "remove") {
302+
messages.splice(index, 1);
303+
} else {
304+
messages[index] = toMessageDict(chunk);
305+
}
302306

303307
return options.setMessages(values, messages);
304308
});

libs/sdk/src/react/messages.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import {
22
type BaseMessage,
33
type BaseMessageChunk,
4+
RemoveMessage,
45
convertToChunk,
56
coerceMessageLikeToMessage,
67
isBaseMessageChunk,
@@ -15,6 +16,16 @@ function tryConvertToChunk(message: BaseMessage): BaseMessageChunk | null {
1516
return null;
1617
}
1718
}
19+
20+
function tryCoerceMessageLikeToMessage(message: Message): BaseMessage {
21+
// TODO: this is unnecessary with https://github.com/langchain-ai/langchainjs/pull/8941
22+
if (message.type === "remove" && message.id != null) {
23+
return new RemoveMessage({ ...message, id: message.id });
24+
}
25+
26+
return coerceMessageLikeToMessage(message);
27+
}
28+
1829
export class MessageTupleManager {
1930
chunks: Record<
2031
string,
@@ -42,7 +53,7 @@ export class MessageTupleManager {
4253
.toLowerCase() as Message["type"];
4354
}
4455

45-
const message = coerceMessageLikeToMessage(serialized);
56+
const message = tryCoerceMessageLikeToMessage(serialized);
4657
const chunk = tryConvertToChunk(message);
4758

4859
const { id } = chunk ?? message;

0 commit comments

Comments
 (0)