Skip to content

Commit d2a4b22

Browse files
committed
GH-4648: Added reactive approach to method toolbacks, allowing to backpressure down until the tool definition, and to pass context to them.
Signed-off-by: Björn Cersowsky <[email protected]> Signed-off-by: Björn Cersowsky <[email protected]>
1 parent ffe11b4 commit d2a4b22

File tree

2 files changed

+56
-47
lines changed

2 files changed

+56
-47
lines changed

mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
import java.util.function.BiFunction;
2323
import java.util.stream.Stream;
2424

25-
import com.fasterxml.jackson.annotation.JsonAlias;
26-
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
2725
import io.micrometer.common.util.StringUtils;
2826
import io.modelcontextprotocol.client.McpAsyncClient;
2927
import io.modelcontextprotocol.client.McpSyncClient;
@@ -39,7 +37,7 @@
3937
import org.springframework.ai.chat.model.ToolContext;
4038
import org.springframework.ai.model.ModelOptionsUtils;
4139
import org.springframework.ai.tool.ToolCallback;
42-
import org.springframework.lang.Nullable;
40+
import org.springframework.ai.tool.method.MethodToolCallback;
4341
import org.springframework.util.CollectionUtils;
4442
import org.springframework.util.MimeType;
4543

@@ -196,12 +194,13 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
196194
public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(ToolCallback toolCallback,
197195
MimeType mimeType) {
198196

199-
SharedSyncToolSpecification sharedSpec = toSharedSyncToolSpecification(toolCallback, mimeType);
197+
SharedAsyncToolSpecification sharedSpec = toSharedAsyncToolSpecification(toolCallback, mimeType);
200198

201199
return new McpServerFeatures.SyncToolSpecification(sharedSpec.tool(),
202200
(exchange, map) -> sharedSpec.sharedHandler()
203-
.apply(exchange, new CallToolRequest(sharedSpec.tool().name(), map)),
204-
(exchange, request) -> sharedSpec.sharedHandler().apply(exchange, request));
201+
.apply(exchange, new CallToolRequest(sharedSpec.tool().name(), map))
202+
.block(),
203+
(exchange, request) -> sharedSpec.sharedHandler().apply(exchange, request).block());
205204
}
206205

207206
/**
@@ -219,15 +218,15 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
219218
public static McpStatelessServerFeatures.SyncToolSpecification toStatelessSyncToolSpecification(
220219
ToolCallback toolCallback, MimeType mimeType) {
221220

222-
var sharedSpec = toSharedSyncToolSpecification(toolCallback, mimeType);
221+
var sharedSpec = toSharedAsyncToolSpecification(toolCallback, mimeType);
223222

224223
return McpStatelessServerFeatures.SyncToolSpecification.builder()
225224
.tool(sharedSpec.tool())
226-
.callHandler((exchange, request) -> sharedSpec.sharedHandler().apply(exchange, request))
225+
.callHandler((exchange, request) -> sharedSpec.sharedHandler().apply(exchange, request).block())
227226
.build();
228227
}
229228

230-
private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCallback toolCallback,
229+
private static SharedAsyncToolSpecification toSharedAsyncToolSpecification(ToolCallback toolCallback,
231230
MimeType mimeType) {
232231

233232
var tool = McpSchema.Tool.builder()
@@ -237,20 +236,31 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal
237236
McpSchema.JsonSchema.class))
238237
.build();
239238

240-
return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> {
241-
try {
242-
String callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request.arguments()),
243-
new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchangeOrContext)));
239+
return new SharedAsyncToolSpecification(tool, (exchangeOrContext, request) -> {
240+
final String toolRequest = ModelOptionsUtils.toJsonString(request.arguments());
241+
final ToolContext toolContext = new ToolContext(Map.of(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchangeOrContext));
242+
final Mono<String> callResult;
243+
if (toolCallback instanceof MethodToolCallback reactiveMethodToolCallback) {
244+
callResult = reactiveMethodToolCallback.callReactive(toolRequest, toolContext);
245+
}
246+
else {
247+
callResult = Mono.fromCallable(() -> toolCallback.call(toolRequest, toolContext));
248+
}
249+
return callResult.map(result -> {
244250
if (mimeType != null && mimeType.toString().startsWith("image")) {
245251
McpSchema.Annotations annotations = new McpSchema.Annotations(List.of(Role.ASSISTANT), null);
246-
return new McpSchema.CallToolResult(
247-
List.of(new McpSchema.ImageContent(annotations, callResult, mimeType.toString())), false);
252+
return McpSchema.CallToolResult.builder()
253+
.addContent(new McpSchema.ImageContent(annotations, result, mimeType.toString()))
254+
.isError(false)
255+
.build();
248256
}
249-
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false);
250-
}
251-
catch (Exception e) {
252-
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(e.getMessage())), true);
253-
}
257+
return McpSchema.CallToolResult.builder().addTextContent(result).isError(false).build();
258+
})
259+
.onErrorResume(Exception.class,
260+
error -> Mono.fromSupplier(() -> McpSchema.CallToolResult.builder()
261+
.addTextContent(error.getMessage())
262+
.isError(true)
263+
.build()));
254264
});
255265
}
256266

@@ -331,7 +341,6 @@ public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification(
331341
* This method enables Spring AI tools to be exposed as asynchronous MCP tools that
332342
* can be discovered and invoked by language models. The conversion process:
333343
* <ul>
334-
* <li>First converts the callback to a synchronous specification</li>
335344
* <li>Wraps the synchronous execution in a reactive Mono</li>
336345
* <li>Configures execution on a bounded elastic scheduler for non-blocking
337346
* operation</li>
@@ -352,26 +361,24 @@ public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification(
352361
public static McpServerFeatures.AsyncToolSpecification toAsyncToolSpecification(ToolCallback toolCallback,
353362
MimeType mimeType) {
354363

355-
McpServerFeatures.SyncToolSpecification syncToolSpecification = toSyncToolSpecification(toolCallback, mimeType);
364+
SharedAsyncToolSpecification asyncToolSpecification = toSharedAsyncToolSpecification(toolCallback, mimeType);
356365

357366
return McpServerFeatures.AsyncToolSpecification.builder()
358-
.tool(syncToolSpecification.tool())
359-
.callHandler((exchange, request) -> Mono
360-
.fromCallable(
361-
() -> syncToolSpecification.callHandler().apply(new McpSyncServerExchange(exchange), request))
367+
.tool(asyncToolSpecification.tool())
368+
.callHandler((exchange, request) -> asyncToolSpecification.sharedHandler()
369+
.apply(new McpSyncServerExchange(exchange), request)
362370
.subscribeOn(Schedulers.boundedElastic()))
363371
.build();
364372
}
365373

366374
public static McpStatelessServerFeatures.AsyncToolSpecification toStatelessAsyncToolSpecification(
367375
ToolCallback toolCallback, MimeType mimeType) {
368376

369-
McpStatelessServerFeatures.SyncToolSpecification statelessSyncToolSpecification = toStatelessSyncToolSpecification(
370-
toolCallback, mimeType);
377+
SharedAsyncToolSpecification asyncToolSpecification = toSharedAsyncToolSpecification(toolCallback, mimeType);
371378

372-
return new McpStatelessServerFeatures.AsyncToolSpecification(statelessSyncToolSpecification.tool(),
373-
(context, request) -> Mono
374-
.fromCallable(() -> statelessSyncToolSpecification.callHandler().apply(context, request))
379+
return new McpStatelessServerFeatures.AsyncToolSpecification(asyncToolSpecification.tool(),
380+
(context, request) -> asyncToolSpecification.sharedHandler()
381+
.apply(context, request)
375382
.subscribeOn(Schedulers.boundedElastic()));
376383
}
377384

@@ -441,13 +448,8 @@ public static List<ToolCallback> getToolCallbacksFromAsyncClients(List<McpAsyncC
441448
return List.of((AsyncMcpToolCallbackProvider.builder().mcpClients(asyncMcpClients).build().getToolCallbacks()));
442449
}
443450

444-
@JsonIgnoreProperties(ignoreUnknown = true)
445-
// @formatter:off
446-
private record Base64Wrapper(@JsonAlias("mimetype") @Nullable MimeType mimeType, @JsonAlias({
447-
"base64", "b64", "imageData" }) @Nullable String data) {
451+
private record SharedAsyncToolSpecification(McpSchema.Tool tool,
452+
BiFunction<Object, CallToolRequest, Mono<McpSchema.CallToolResult>> sharedHandler) {
448453
}
449454

450-
private record SharedSyncToolSpecification(McpSchema.Tool tool,
451-
BiFunction<Object, CallToolRequest, McpSchema.CallToolResult> sharedHandler) {
452-
}
453455
}

spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import java.util.stream.Stream;
2525

2626
import com.fasterxml.jackson.core.type.TypeReference;
27+
import org.reactivestreams.Publisher;
2728
import org.slf4j.Logger;
2829
import org.slf4j.LoggerFactory;
30+
import reactor.core.publisher.Mono;
2931

3032
import org.springframework.ai.chat.model.ToolContext;
3133
import org.springframework.ai.tool.ToolCallback;
@@ -96,6 +98,10 @@ public String call(String toolInput) {
9698

9799
@Override
98100
public String call(String toolInput, @Nullable ToolContext toolContext) {
101+
return callReactive(toolInput, toolContext).block();
102+
}
103+
104+
public Mono<String> callReactive(String toolInput, @Nullable ToolContext toolContext) {
99105
Assert.hasText(toolInput, "toolInput cannot be null or empty");
100106

101107
logger.debug("Starting execution of tool: {}", this.toolDefinition.name());
@@ -106,13 +112,13 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {
106112

107113
Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);
108114

109-
Object result = callMethod(methodArguments);
110-
111-
logger.debug("Successful execution of tool: {}", this.toolDefinition.name());
115+
return callMethod(methodArguments).map(result -> {
116+
logger.debug("Successful execution of tool: {}", this.toolDefinition.name());
112117

113-
Type returnType = this.toolMethod.getGenericReturnType();
118+
Type returnType = this.toolMethod.getGenericReturnType();
114119

115-
return this.toolCallResultConverter.convert(result, returnType);
120+
return this.toolCallResultConverter.convert(result, returnType);
121+
});
116122
}
117123

118124
private void validateToolContextSupport(@Nullable ToolContext toolContext) {
@@ -155,15 +161,16 @@ private Object buildTypedArgument(@Nullable Object value, Type type) {
155161
return JsonParser.fromJson(json, type);
156162
}
157163

158-
@Nullable
159-
private Object callMethod(Object[] methodArguments) {
164+
private Mono<Object> callMethod(Object[] methodArguments) {
160165
if (isObjectNotPublic() || isMethodNotPublic()) {
161166
this.toolMethod.setAccessible(true);
162167
}
163168

164-
Object result;
169+
final Mono<Object> result;
165170
try {
166-
result = this.toolMethod.invoke(this.toolObject, methodArguments);
171+
result = Publisher.class.isAssignableFrom(this.toolMethod.getReturnType())
172+
? Mono.from((Publisher<Object>) this.toolMethod.invoke(this.toolObject, methodArguments))
173+
: Mono.justOrEmpty(this.toolMethod.invoke(this.toolObject, methodArguments));
167174
}
168175
catch (IllegalAccessException ex) {
169176
throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);

0 commit comments

Comments
 (0)