Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions contrib/spring-ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,80 @@ adk:
metrics-enabled: true
```

## Tool / MCP Bridge — Spring AI `ToolCallback` as ADK `BaseTool`

In addition to wrapping Spring AI `ChatModel`s as ADK `BaseLlm`s, this library can wrap **any Spring AI `ToolCallback`** as an ADK `BaseTool` via `SpringAiToolCallbackBackedAdkTool`. This unlocks the full Spring AI tool ecosystem for ADK agents:

- **MCP tools** — `SyncMcpToolCallback` / `AsyncMcpToolCallback` produced by `spring-ai-starter-mcp-client` from `spring.ai.mcp.client.*` properties
- **`@Tool`-annotated methods** — Spring AI's annotation-driven function calling
- **`FunctionToolCallback`** — programmatically declared tools
- Any other implementation of `org.springframework.ai.tool.ToolCallback`

The bridge is the **reverse direction** of the existing `ToolConverter` (which goes ADK → Spring AI). Together they make ADK and Spring AI tool ecosystems fully interoperable.

### How it works

`SpringAiToolCallbackBackedAdkTool` reads `ToolCallback.getToolDefinition()` to extract the tool name, description, and JSON Schema. The schema is converted to ADK's `Schema` type via `Schema.fromJson(...)`; if parsing fails the bridge falls back to the `parametersJsonSchema(Object)` escape hatch (no hard failure). At invocation time the bridge serializes the `Map<String, Object>` arguments to JSON, dispatches to `ToolCallback.call(String)`, and parses the JSON response back to `Map<String, Object>`. Non-object responses (primitives / arrays / arbitrary strings) are wrapped under a `"result"` key for structural consistency.

### Usage — MCP tools via Spring AI

`application.yaml`:

```yaml
spring:
ai:
mcp:
client:
sse:
connections:
filesystem:
url: http://localhost:3000
```

Java:

```java
import com.google.adk.agents.LlmAgent;
import com.google.adk.models.springai.SpringAI;
import com.google.adk.models.springai.bridge.SpringAiToolCallbackBackedAdkTool;
import org.springframework.ai.tool.ToolCallback;

@Configuration
class AgentConfig {

@Bean
public LlmAgent rootAgent(SpringAI springAI, List<ToolCallback> mcpToolCallbacks) {
return LlmAgent.builder()
.name("root_agent")
.model(springAI)
.tools(SpringAiToolCallbackBackedAdkTool.wrapAll(mcpToolCallbacks))
.instruction("Use the available tools to answer the user.")
.build();
}
}
```

That's it. The `List<ToolCallback>` is auto-injected by `spring-ai-starter-mcp-client`'s `McpToolCallbackAutoConfiguration`. The bridge converts every callback into a `BaseTool`. The agent uses them transparently.

### Usage — single tool

When you only need to wrap one callback:

```java
ToolCallback callback = /* obtained from any Spring AI source */;
BaseTool adkTool = new SpringAiToolCallbackBackedAdkTool(callback);

LlmAgent agent = LlmAgent.builder()
.name("my_agent")
.model(springAI)
.tools(List.of(adkTool))
.build();
```

### Coexistence with ADK's native MCP

ADK ships its own MCP client in `com.google.adk.tools.mcp.*` (CLI / non-Spring-Boot scenarios). The two paths can be mixed at the `.tools(...)` boundary — both produce `BaseTool` instances — but it is strongly recommended to **pick one** in any given application. The Spring AI MCP route is the natural choice for Spring Boot apps because everything is property-driven; ADK's native `McpToolset` remains the right choice for non-Spring usage.

## Architecture

### Core Components
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
package com.google.adk.models.springai.bridge;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.Schema;
import io.reactivex.rxjava3.core.Single;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.tool.ToolCallback;

/**
* ADK {@link BaseTool} backed by a Spring AI {@link ToolCallback}.
*
* <p>Adapts <strong>any</strong> Spring AI {@code ToolCallback} (including {@code
* SyncMcpToolCallback} / {@code AsyncMcpToolCallback} from {@code spring-ai-starter-mcp-client},
* {@code FunctionToolCallback}, {@code @Tool}-annotated method callbacks, etc.) into the ADK tool
* model so they can be attached to an {@link com.google.adk.agents.LlmAgent}.
*
* <p>Schema is extracted from {@link ToolCallback#getToolDefinition()}: {@code name}, {@code
* description}, and the JSON Schema returned by {@code inputSchema()} are mapped into the ADK
* {@link FunctionDeclaration}/{@link Schema} pair via {@link Schema#fromJson(String)}.
*
* <p>Invocation: the ADK runtime calls {@link #runAsync(Map, ToolContext)} with parsed arguments →
* this class serializes them back to JSON and dispatches to {@link ToolCallback#call(String)},
* which returns a JSON string. The string is parsed back into a {@code Map<String, Object>} for the
* ADK flow. When the underlying callback's result is not a JSON object (e.g. a primitive or array),
* the value is wrapped under a {@code "result"} key.
*/
public class SpringAiToolCallbackBackedAdkTool extends BaseTool {

private static final Logger logger =
LoggerFactory.getLogger(SpringAiToolCallbackBackedAdkTool.class);

private static final TypeReference<Map<String, Object>> MAP_TYPE = new TypeReference<>() {};

private final ToolCallback toolCallback;
private final ObjectMapper objectMapper;
private final Optional<FunctionDeclaration> declaration;

/** Wraps a Spring AI {@link ToolCallback} as an ADK tool using the default ADK JSON mapper. */
public SpringAiToolCallbackBackedAdkTool(ToolCallback toolCallback) {
this(toolCallback, JsonBaseModel.getMapper());
}

/** Wraps a Spring AI {@link ToolCallback} as an ADK tool with a custom JSON mapper. */
public SpringAiToolCallbackBackedAdkTool(ToolCallback toolCallback, ObjectMapper objectMapper) {
super(
Objects.requireNonNull(toolCallback, "toolCallback").getToolDefinition().name(),
toolCallback.getToolDefinition().description());
this.toolCallback = toolCallback;
this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper");
this.declaration = buildDeclaration(toolCallback);
}

/**
* Converts every {@link ToolCallback} in the input list into a {@link BaseTool}. Useful when
* fanning out a {@code List<ToolCallback>} (e.g. the result of {@code McpToolCallbackProvider})
* to an agent's {@code .tools(...)} list.
*/
public static List<BaseTool> wrapAll(List<? extends ToolCallback> toolCallbacks) {
return toolCallbacks.stream()
.map(SpringAiToolCallbackBackedAdkTool::new)
.collect(Collectors.toList());
}

@Override
public Optional<FunctionDeclaration> declaration() {
return declaration;
}

@Override
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
return Single.fromCallable(
() -> {
String requestJson = objectMapper.writeValueAsString(args == null ? Map.of() : args);
if (logger.isDebugEnabled()) {
logger.debug("Invoking Spring AI tool '{}' with args: {}", name(), requestJson);
}
String responseJson = toolCallback.call(requestJson);
return parseResponse(responseJson);
});
}

private Map<String, Object> parseResponse(String responseJson) {
if (responseJson == null || responseJson.isBlank()) {
return ImmutableMap.of();
}
try {
return objectMapper.readValue(responseJson, MAP_TYPE);
} catch (Exception notAnObject) {
// The callback returned a primitive, array, or arbitrary string. Wrap so the agent gets a
// structured result rather than a parse failure.
Object decoded = tryDecodeAsJsonValue(responseJson);
return ImmutableMap.of("result", decoded);
}
}

private Object tryDecodeAsJsonValue(String responseJson) {
try {
return objectMapper.readTree(responseJson);
} catch (Exception notJson) {
return responseJson;
}
}

/** Exposed for tests and downstream tooling. */
public ToolCallback toolCallback() {
return toolCallback;
}

private static Optional<FunctionDeclaration> buildDeclaration(ToolCallback toolCallback) {
var def = toolCallback.getToolDefinition();
FunctionDeclaration.Builder builder = FunctionDeclaration.builder().name(def.name());
if (def.description() != null && !def.description().isBlank()) {
builder.description(def.description());
}
String inputSchema = def.inputSchema();
if (inputSchema != null && !inputSchema.isBlank()) {
try {
builder.parameters(Schema.fromJson(inputSchema));
} catch (Exception parseFailed) {
logger.warn(
"Could not parse Spring AI tool '{}' input schema as ADK Schema; falling back to raw"
+ " JSON schema. Cause: {}",
def.name(),
parseFailed.getMessage());
builder.parametersJsonSchema(inputSchema);
}
}
return Optional.of(builder.build());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*/
package com.google.adk.models.springai.bridge;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;

class SpringAiToolCallbackBackedAdkToolTest {

private static final String SCHEMA =
"{\"type\":\"object\",\"properties\":{\"city\":{\"type\":\"string\"}},\"required\":[\"city\"]}";

@Test
void declaration_isBuiltFromToolDefinition() {
ToolCallback callback = mockCallback("weather", "Returns weather for a city", SCHEMA);

SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback);

assertThat(tool.name()).isEqualTo("weather");
assertThat(tool.description()).isEqualTo("Returns weather for a city");
assertThat(tool.declaration()).isPresent();
assertThat(tool.declaration().get().name()).hasValue("weather");
assertThat(tool.declaration().get().description()).hasValue("Returns weather for a city");
assertThat(tool.declaration().get().parameters()).isPresent();
}

@Test
void runAsync_serializesArgs_andParsesJsonObjectResponse() throws Exception {
ToolCallback callback = mockCallback("weather", "desc", SCHEMA);
when(callback.call(anyString())).thenReturn("{\"forecast\":\"sunny\",\"temp\":22}");

SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback);
Map<String, Object> result = tool.runAsync(Map.of("city", "Paris"), null).blockingGet();

assertThat(result).containsEntry("forecast", "sunny").containsEntry("temp", 22);

ArgumentCaptor<String> argsCaptor = ArgumentCaptor.forClass(String.class);
verify(callback).call(argsCaptor.capture());
assertThat(argsCaptor.getValue()).contains("\"city\":\"Paris\"");
}

@Test
void runAsync_wrapsScalarResponse_underResultKey() {
ToolCallback callback = mockCallback("ping", "desc", SCHEMA);
when(callback.call(anyString())).thenReturn("\"pong\"");

SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback);
Map<String, Object> result = tool.runAsync(Map.of(), null).blockingGet();

assertThat(result).containsKey("result");
}

@Test
void runAsync_emptyResponse_yieldsEmptyMap() {
ToolCallback callback = mockCallback("noop", "desc", SCHEMA);
when(callback.call(anyString())).thenReturn("");

SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback);
Map<String, Object> result = tool.runAsync(Map.of(), null).blockingGet();

assertThat(result).isEmpty();
}

@Test
void runAsync_nullArgs_sendsEmptyObject() {
ToolCallback callback = mockCallback("noop", "desc", SCHEMA);
when(callback.call(anyString())).thenReturn("{}");

SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback);
tool.runAsync(null, null).blockingGet();

ArgumentCaptor<String> argsCaptor = ArgumentCaptor.forClass(String.class);
verify(callback).call(argsCaptor.capture());
assertThat(argsCaptor.getValue()).isEqualTo("{}");
}

@Test
void wrapAll_convertsEveryCallback() {
ToolCallback a = mockCallback("a", "tool a", SCHEMA);
ToolCallback b = mockCallback("b", "tool b", SCHEMA);

var wrapped = SpringAiToolCallbackBackedAdkTool.wrapAll(List.of(a, b));

assertThat(wrapped).hasSize(2);
assertThat(wrapped.get(0).name()).isEqualTo("a");
assertThat(wrapped.get(1).name()).isEqualTo("b");
}

@Test
void invalidSchema_fallsBackToRawJsonSchema_withoutThrowing() {
ToolCallback callback = mockCallback("malformed", "desc", "not valid json {[");

SpringAiToolCallbackBackedAdkTool tool = new SpringAiToolCallbackBackedAdkTool(callback);

assertThat(tool.declaration()).isPresent();
assertThat(tool.declaration().get().parametersJsonSchema()).isPresent();
}

private static ToolCallback mockCallback(String name, String description, String inputSchema) {
ToolCallback cb = mock(ToolCallback.class);
ToolDefinition def =
DefaultToolDefinition.builder()
.name(name)
.description(description)
.inputSchema(inputSchema)
.build();
when(cb.getToolDefinition()).thenReturn(def);
return cb;
}
}