Skip to content

Commit e6c6bbf

Browse files
hntrlchristian-bromann
authored andcommitted
fix(langchain/middleware): allow model strings in summarization middleware
1 parent 7a4a385 commit e6c6bbf

File tree

3 files changed

+66
-2
lines changed

3 files changed

+66
-2
lines changed

.changeset/chatty-spies-shine.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"langchain": patch
3+
---
4+
5+
allow for model strings in summarization middleware

libs/langchain/src/agents/middleware/summarization.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import { REMOVE_ALL_MESSAGES } from "@langchain/langgraph";
1818
import { createMiddleware } from "../middleware.js";
1919
import { countTokensApproximately } from "./utils.js";
2020
import { hasToolCalls } from "../utils.js";
21+
import { initChatModel } from "../../chat_models/universal.js";
2122

2223
const DEFAULT_SUMMARY_PROMPT = `<role>
2324
Context Extraction Assistant
@@ -57,7 +58,7 @@ const SEARCH_RANGE_FOR_TOOL_PAIRS = 5;
5758
type TokenCounter = (messages: BaseMessage[]) => number | Promise<number>;
5859

5960
const contextSchema = z.object({
60-
model: z.custom<BaseLanguageModel>(),
61+
model: z.custom<string | BaseLanguageModel>(),
6162
maxTokensBeforeSummary: z.number().optional(),
6263
messagesToKeep: z.number().default(DEFAULT_MESSAGES_TO_KEEP),
6364
tokenCounter: z
@@ -148,6 +149,11 @@ export function summarizationMiddleware(
148149
} as InferInteropZodOutput<typeof contextSchema>;
149150
const { messages } = state;
150151

152+
const model =
153+
typeof config.model === "string"
154+
? await initChatModel(config.model)
155+
: config.model;
156+
151157
// Ensure all messages have IDs
152158
ensureMessageIds(messages);
153159

@@ -180,7 +186,7 @@ export function summarizationMiddleware(
180186

181187
const summary = await createSummary(
182188
messagesToSummarize,
183-
config.model,
189+
model,
184190
config.summaryPrompt,
185191
tokenCounter
186192
);

libs/langchain/src/agents/middleware/tests/summarization.test.ts

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,36 @@ import { countTokensApproximately } from "../utils.js";
1111
import { createAgent } from "../../index.js";
1212
import { FakeToolCallingChatModel } from "../../tests/utils.js";
1313

14+
// Mock @langchain/anthropic to test model string usage without requiring the built package
15+
vi.mock("@langchain/anthropic", async () => {
16+
const { AIMessage } = await import("@langchain/core/messages");
17+
return {
18+
ChatAnthropic: class MockChatAnthropic {
19+
lc_kwargs: Record<string, any>;
20+
21+
constructor(params?: any) {
22+
this.lc_kwargs = params || {};
23+
}
24+
25+
async invoke() {
26+
return new AIMessage({ content: "Mocked response" });
27+
}
28+
29+
getName() {
30+
return "ChatAnthropic";
31+
}
32+
33+
get _modelType() {
34+
return "chat-anthropic";
35+
}
36+
37+
get lc_runnable() {
38+
return true;
39+
}
40+
},
41+
};
42+
});
43+
1444
describe("summarizationMiddleware", () => {
1545
// Mock summarization model
1646
function createMockSummarizationModel() {
@@ -335,4 +365,27 @@ describe("summarizationMiddleware", () => {
335365
expect(nonSystemMessages.length).toBeGreaterThanOrEqual(messagesToKeep);
336366
expect(nonSystemMessages.length).toBeLessThanOrEqual(messagesToKeep + 3); // Some buffer for safety
337367
});
368+
369+
it("can be created using a model string", async () => {
370+
// Verify the mocked ChatAnthropic exists
371+
const { ChatAnthropic } = await import("@langchain/anthropic");
372+
expect(ChatAnthropic).toBeDefined();
373+
expect(typeof ChatAnthropic).toBe("function");
374+
375+
const model = "anthropic:claude-sonnet-4-20250514";
376+
377+
const middleware = summarizationMiddleware({
378+
model,
379+
maxTokensBeforeSummary: 100,
380+
messagesToKeep: 2,
381+
});
382+
383+
const agent = createAgent({
384+
model,
385+
middleware: [middleware],
386+
});
387+
388+
const result = await agent.invoke({ messages: [] });
389+
expect(result.messages.length).toBeGreaterThan(0);
390+
});
338391
});

0 commit comments

Comments
 (0)