diff --git a/contrib/spring-ai/README.md b/contrib/spring-ai/README.md index 0ce7de4fe..7c20b60ed 100644 --- a/contrib/spring-ai/README.md +++ b/contrib/spring-ai/README.md @@ -329,6 +329,80 @@ adk: metrics-enabled: true ``` +## Tool / MCP Bridge — Spring AI `ToolCallback` as ADK `BaseTool` + +In addition to wrapping Spring AI `ChatModel`s as ADK `BaseLlm`s, this library can wrap **any Spring AI `ToolCallback`** as an ADK `BaseTool` via `SpringAiToolCallbackBackedAdkTool`. This unlocks the full Spring AI tool ecosystem for ADK agents: + +- **MCP tools** — `SyncMcpToolCallback` / `AsyncMcpToolCallback` produced by `spring-ai-starter-mcp-client` from `spring.ai.mcp.client.*` properties +- **`@Tool`-annotated methods** — Spring AI's annotation-driven function calling +- **`FunctionToolCallback`** — programmatically declared tools +- Any other implementation of `org.springframework.ai.tool.ToolCallback` + +The bridge is the **reverse direction** of the existing `ToolConverter` (which goes ADK → Spring AI). Together they make ADK and Spring AI tool ecosystems fully interoperable. + +### How it works + +`SpringAiToolCallbackBackedAdkTool` reads `ToolCallback.getToolDefinition()` to extract the tool name, description, and JSON Schema. The schema is converted to ADK's `Schema` type via `Schema.fromJson(...)`; if parsing fails the bridge falls back to the `parametersJsonSchema(Object)` escape hatch (no hard failure). At invocation time the bridge serializes the `Map` arguments to JSON, dispatches to `ToolCallback.call(String)`, and parses the JSON response back to `Map`. Non-object responses (primitives / arrays / arbitrary strings) are wrapped under a `"result"` key for structural consistency. + +### Usage — MCP tools via Spring AI + +`application.yaml`: + +```yaml +spring: + ai: + mcp: + client: + sse: + connections: + filesystem: + url: http://localhost:3000 +``` + +Java: + +```java +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.springai.SpringAI; +import com.google.adk.models.springai.bridge.SpringAiToolCallbackBackedAdkTool; +import org.springframework.ai.tool.ToolCallback; + +@Configuration +class AgentConfig { + + @Bean + public LlmAgent rootAgent(SpringAI springAI, List mcpToolCallbacks) { + return LlmAgent.builder() + .name("root_agent") + .model(springAI) + .tools(SpringAiToolCallbackBackedAdkTool.wrapAll(mcpToolCallbacks)) + .instruction("Use the available tools to answer the user.") + .build(); + } +} +``` + +That's it. The `List` is auto-injected by `spring-ai-starter-mcp-client`'s `McpToolCallbackAutoConfiguration`. The bridge converts every callback into a `BaseTool`. The agent uses them transparently. + +### Usage — single tool + +When you only need to wrap one callback: + +```java +ToolCallback callback = /* obtained from any Spring AI source */; +BaseTool adkTool = new SpringAiToolCallbackBackedAdkTool(callback); + +LlmAgent agent = LlmAgent.builder() + .name("my_agent") + .model(springAI) + .tools(List.of(adkTool)) + .build(); +``` + +### Coexistence with ADK's native MCP + +ADK ships its own MCP client in `com.google.adk.tools.mcp.*` (CLI / non-Spring-Boot scenarios). The two paths can be mixed at the `.tools(...)` boundary — both produce `BaseTool` instances — but it is strongly recommended to **pick one** in any given application. The Spring AI MCP route is the natural choice for Spring Boot apps because everything is property-driven; ADK's native `McpToolset` remains the right choice for non-Spring usage. + ## Architecture ### Core Components diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/bridge/SpringAiToolCallbackBackedAdkTool.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/bridge/SpringAiToolCallbackBackedAdkTool.java new file mode 100644 index 000000000..4b133da6f --- /dev/null +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/bridge/SpringAiToolCallbackBackedAdkTool.java @@ -0,0 +1,151 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package com.google.adk.models.springai.bridge; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Single; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.tool.ToolCallback; + +/** + * ADK {@link BaseTool} backed by a Spring AI {@link ToolCallback}. + * + *

Adapts any Spring AI {@code ToolCallback} (including {@code + * SyncMcpToolCallback} / {@code AsyncMcpToolCallback} from {@code spring-ai-starter-mcp-client}, + * {@code FunctionToolCallback}, {@code @Tool}-annotated method callbacks, etc.) into the ADK tool + * model so they can be attached to an {@link com.google.adk.agents.LlmAgent}. + * + *

Schema is extracted from {@link ToolCallback#getToolDefinition()}: {@code name}, {@code + * description}, and the JSON Schema returned by {@code inputSchema()} are mapped into the ADK + * {@link FunctionDeclaration}/{@link Schema} pair via {@link Schema#fromJson(String)}. + * + *

Invocation: the ADK runtime calls {@link #runAsync(Map, ToolContext)} with parsed arguments → + * this class serializes them back to JSON and dispatches to {@link ToolCallback#call(String)}, + * which returns a JSON string. The string is parsed back into a {@code Map} for the + * ADK flow. When the underlying callback's result is not a JSON object (e.g. a primitive or array), + * the value is wrapped under a {@code "result"} key. + */ +public class SpringAiToolCallbackBackedAdkTool extends BaseTool { + + private static final Logger logger = + LoggerFactory.getLogger(SpringAiToolCallbackBackedAdkTool.class); + + private static final TypeReference> MAP_TYPE = new TypeReference<>() {}; + + private final ToolCallback toolCallback; + private final ObjectMapper objectMapper; + private final Optional declaration; + + /** Wraps a Spring AI {@link ToolCallback} as an ADK tool using the default ADK JSON mapper. */ + public SpringAiToolCallbackBackedAdkTool(ToolCallback toolCallback) { + this(toolCallback, JsonBaseModel.getMapper()); + } + + /** Wraps a Spring AI {@link ToolCallback} as an ADK tool with a custom JSON mapper. */ + public SpringAiToolCallbackBackedAdkTool(ToolCallback toolCallback, ObjectMapper objectMapper) { + super( + Objects.requireNonNull(toolCallback, "toolCallback").getToolDefinition().name(), + toolCallback.getToolDefinition().description()); + this.toolCallback = toolCallback; + this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper"); + this.declaration = buildDeclaration(toolCallback); + } + + /** + * Converts every {@link ToolCallback} in the input list into a {@link BaseTool}. Useful when + * fanning out a {@code List} (e.g. the result of {@code McpToolCallbackProvider}) + * to an agent's {@code .tools(...)} list. + */ + public static List wrapAll(List toolCallbacks) { + return toolCallbacks.stream() + .map(SpringAiToolCallbackBackedAdkTool::new) + .collect(Collectors.toList()); + } + + @Override + public Optional declaration() { + return declaration; + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + return Single.fromCallable( + () -> { + String requestJson = objectMapper.writeValueAsString(args == null ? Map.of() : args); + if (logger.isDebugEnabled()) { + logger.debug("Invoking Spring AI tool '{}' with args: {}", name(), requestJson); + } + String responseJson = toolCallback.call(requestJson); + return parseResponse(responseJson); + }); + } + + private Map parseResponse(String responseJson) { + if (responseJson == null || responseJson.isBlank()) { + return ImmutableMap.of(); + } + try { + return objectMapper.readValue(responseJson, MAP_TYPE); + } catch (Exception notAnObject) { + // The callback returned a primitive, array, or arbitrary string. Wrap so the agent gets a + // structured result rather than a parse failure. + Object decoded = tryDecodeAsJsonValue(responseJson); + return ImmutableMap.of("result", decoded); + } + } + + private Object tryDecodeAsJsonValue(String responseJson) { + try { + return objectMapper.readTree(responseJson); + } catch (Exception notJson) { + return responseJson; + } + } + + /** Exposed for tests and downstream tooling. */ + public ToolCallback toolCallback() { + return toolCallback; + } + + private static Optional buildDeclaration(ToolCallback toolCallback) { + var def = toolCallback.getToolDefinition(); + FunctionDeclaration.Builder builder = FunctionDeclaration.builder().name(def.name()); + if (def.description() != null && !def.description().isBlank()) { + builder.description(def.description()); + } + String inputSchema = def.inputSchema(); + if (inputSchema != null && !inputSchema.isBlank()) { + try { + builder.parameters(Schema.fromJson(inputSchema)); + } catch (Exception parseFailed) { + logger.warn( + "Could not parse Spring AI tool '{}' input schema as ADK Schema; falling back to raw" + + " JSON schema. Cause: {}", + def.name(), + parseFailed.getMessage()); + builder.parametersJsonSchema(inputSchema); + } + } + return Optional.of(builder.build()); + } +} diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/bridge/SpringAiToolCallbackBackedAdkToolTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/bridge/SpringAiToolCallbackBackedAdkToolTest.java new file mode 100644 index 000000000..8847c1ce7 --- /dev/null +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/bridge/SpringAiToolCallbackBackedAdkToolTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package com.google.adk.models.springai.bridge; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; + +class SpringAiToolCallbackBackedAdkToolTest { + + private static final String SCHEMA = + "{\"type\":\"object\",\"properties\":{\"city\":{\"type\":\"string\"}},\"required\":[\"city\"]}"; + + @Test + void declaration_isBuiltFromToolDefinition() { + ToolCallback callback = mockCallback("weather", "Returns weather for a city", SCHEMA); + + SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback); + + assertThat(tool.name()).isEqualTo("weather"); + assertThat(tool.description()).isEqualTo("Returns weather for a city"); + assertThat(tool.declaration()).isPresent(); + assertThat(tool.declaration().get().name()).hasValue("weather"); + assertThat(tool.declaration().get().description()).hasValue("Returns weather for a city"); + assertThat(tool.declaration().get().parameters()).isPresent(); + } + + @Test + void runAsync_serializesArgs_andParsesJsonObjectResponse() throws Exception { + ToolCallback callback = mockCallback("weather", "desc", SCHEMA); + when(callback.call(anyString())).thenReturn("{\"forecast\":\"sunny\",\"temp\":22}"); + + SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback); + Map result = tool.runAsync(Map.of("city", "Paris"), null).blockingGet(); + + assertThat(result).containsEntry("forecast", "sunny").containsEntry("temp", 22); + + ArgumentCaptor argsCaptor = ArgumentCaptor.forClass(String.class); + verify(callback).call(argsCaptor.capture()); + assertThat(argsCaptor.getValue()).contains("\"city\":\"Paris\""); + } + + @Test + void runAsync_wrapsScalarResponse_underResultKey() { + ToolCallback callback = mockCallback("ping", "desc", SCHEMA); + when(callback.call(anyString())).thenReturn("\"pong\""); + + SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback); + Map result = tool.runAsync(Map.of(), null).blockingGet(); + + assertThat(result).containsKey("result"); + } + + @Test + void runAsync_emptyResponse_yieldsEmptyMap() { + ToolCallback callback = mockCallback("noop", "desc", SCHEMA); + when(callback.call(anyString())).thenReturn(""); + + SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback); + Map result = tool.runAsync(Map.of(), null).blockingGet(); + + assertThat(result).isEmpty(); + } + + @Test + void runAsync_nullArgs_sendsEmptyObject() { + ToolCallback callback = mockCallback("noop", "desc", SCHEMA); + when(callback.call(anyString())).thenReturn("{}"); + + SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback); + tool.runAsync(null, null).blockingGet(); + + ArgumentCaptor argsCaptor = ArgumentCaptor.forClass(String.class); + verify(callback).call(argsCaptor.capture()); + assertThat(argsCaptor.getValue()).isEqualTo("{}"); + } + + @Test + void wrapAll_convertsEveryCallback() { + ToolCallback a = mockCallback("a", "tool a", SCHEMA); + ToolCallback b = mockCallback("b", "tool b", SCHEMA); + + var wrapped = SpringAiToolCallbackBackedAdkTool.wrapAll(List.of(a, b)); + + assertThat(wrapped).hasSize(2); + assertThat(wrapped.get(0).name()).isEqualTo("a"); + assertThat(wrapped.get(1).name()).isEqualTo("b"); + } + + @Test + void invalidSchema_fallsBackToRawJsonSchema_withoutThrowing() { + ToolCallback callback = mockCallback("malformed", "desc", "not valid json {["); + + SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback); + + assertThat(tool.declaration()).isPresent(); + assertThat(tool.declaration().get().parametersJsonSchema()).isPresent(); + } + + private static ToolCallback mockCallback(String name, String description, String inputSchema) { + ToolCallback cb = mock(ToolCallback.class); + ToolDefinition def = + DefaultToolDefinition.builder() + .name(name) + .description(description) + .inputSchema(inputSchema) + .build(); + when(cb.getToolDefinition()).thenReturn(def); + return cb; + } +}