From c29b9c60bfe0173fc080ed575e3c65a5c1e99bea Mon Sep 17 00:00:00 2001 From: Nebrass Lamouchi Date: Mon, 5 Jan 2026 18:17:59 +0400 Subject: [PATCH] feat: Add ReasoningBank for reusable reasoning strategies Implements ReasoningBank feature based on arXiv:2509.25140 paper. Core components: - ReasoningStrategy: distilled, reusable reasoning approach - ReasoningTrace: captures raw task execution data - BaseReasoningBankService: interface for storage/retrieval - InMemoryReasoningBankService: in-memory implementation with keyword matching - LoadReasoningStrategyTool: tool for agents to retrieve relevant strategies - SearchReasoningResponse: response model for strategy search Integration: - Added reasoningBankService to InvocationContext - Added searchReasoningStrategies() to ToolContext Includes 20 unit tests covering data models and service functionality. --- .../google/adk/agents/InvocationContext.java | 24 ++ .../reasoning/BaseReasoningBankService.java | 68 +++++ .../InMemoryReasoningBankService.java | 180 +++++++++++++ .../adk/reasoning/ReasoningStrategy.java | 130 +++++++++ .../google/adk/reasoning/ReasoningTrace.java | 139 ++++++++++ .../reasoning/SearchReasoningResponse.java | 49 ++++ .../tools/LoadReasoningStrategyResponse.java | 25 ++ .../adk/tools/LoadReasoningStrategyTool.java | 84 ++++++ .../com/google/adk/tools/ToolContext.java | 11 + .../InMemoryReasoningBankServiceTest.java | 251 ++++++++++++++++++ .../adk/reasoning/ReasoningStrategyTest.java | 99 +++++++ .../adk/reasoning/ReasoningTraceTest.java | 104 ++++++++ 12 files changed, 1164 insertions(+) create mode 100644 core/src/main/java/com/google/adk/reasoning/BaseReasoningBankService.java create mode 100644 core/src/main/java/com/google/adk/reasoning/InMemoryReasoningBankService.java create mode 100644 core/src/main/java/com/google/adk/reasoning/ReasoningStrategy.java create mode 100644 core/src/main/java/com/google/adk/reasoning/ReasoningTrace.java create mode 100644 core/src/main/java/com/google/adk/reasoning/SearchReasoningResponse.java create mode 100644 core/src/main/java/com/google/adk/tools/LoadReasoningStrategyResponse.java create mode 100644 core/src/main/java/com/google/adk/tools/LoadReasoningStrategyTool.java create mode 100644 core/src/test/java/com/google/adk/reasoning/InMemoryReasoningBankServiceTest.java create mode 100644 core/src/test/java/com/google/adk/reasoning/ReasoningStrategyTest.java create mode 100644 core/src/test/java/com/google/adk/reasoning/ReasoningTraceTest.java diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 9491353fd..50e74060c 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -22,6 +22,7 @@ import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; +import com.google.adk.reasoning.BaseReasoningBankService; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableSet; @@ -42,6 +43,7 @@ public class InvocationContext { private final BaseSessionService sessionService; private final BaseArtifactService artifactService; private final BaseMemoryService memoryService; + private final BaseReasoningBankService reasoningBankService; private final PluginManager pluginManager; private final Optional liveRequestQueue; private final Map activeStreamingTools = new ConcurrentHashMap<>(); @@ -60,6 +62,7 @@ private InvocationContext(Builder builder) { this.sessionService = builder.sessionService; this.artifactService = builder.artifactService; this.memoryService = builder.memoryService; + this.reasoningBankService = builder.reasoningBankService; this.pluginManager = builder.pluginManager; this.liveRequestQueue = builder.liveRequestQueue; this.branch = builder.branch; @@ -204,6 +207,7 @@ public static InvocationContext copyOf(InvocationContext other) { .sessionService(other.sessionService) .artifactService(other.artifactService) .memoryService(other.memoryService) + .reasoningBankService(other.reasoningBankService) .pluginManager(other.pluginManager) .liveRequestQueue(other.liveRequestQueue) .branch(other.branch) @@ -234,6 +238,11 @@ public BaseMemoryService memoryService() { return memoryService; } + /** Returns the reasoning bank service for accessing reasoning strategies. */ + public BaseReasoningBankService reasoningBankService() { + return reasoningBankService; + } + /** Returns the plugin manager for accessing tools and plugins. */ public PluginManager pluginManager() { return pluginManager; @@ -376,6 +385,7 @@ public static class Builder { private BaseSessionService sessionService; private BaseArtifactService artifactService; private BaseMemoryService memoryService; + private BaseReasoningBankService reasoningBankService; private PluginManager pluginManager = new PluginManager(); private Optional liveRequestQueue = Optional.empty(); private Optional branch = Optional.empty(); @@ -423,6 +433,18 @@ public Builder memoryService(BaseMemoryService memoryService) { return this; } + /** + * Sets the reasoning bank service for accessing reasoning strategies. + * + * @param reasoningBankService the reasoning bank service to use. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder reasoningBankService(BaseReasoningBankService reasoningBankService) { + this.reasoningBankService = reasoningBankService; + return this; + } + /** * Sets the plugin manager for accessing tools and plugins. * @@ -608,6 +630,7 @@ public boolean equals(Object o) { && Objects.equals(sessionService, that.sessionService) && Objects.equals(artifactService, that.artifactService) && Objects.equals(memoryService, that.memoryService) + && Objects.equals(reasoningBankService, that.reasoningBankService) && Objects.equals(pluginManager, that.pluginManager) && Objects.equals(liveRequestQueue, that.liveRequestQueue) && Objects.equals(activeStreamingTools, that.activeStreamingTools) @@ -626,6 +649,7 @@ public int hashCode() { sessionService, artifactService, memoryService, + reasoningBankService, pluginManager, liveRequestQueue, activeStreamingTools, diff --git a/core/src/main/java/com/google/adk/reasoning/BaseReasoningBankService.java b/core/src/main/java/com/google/adk/reasoning/BaseReasoningBankService.java new file mode 100644 index 000000000..1f5877e68 --- /dev/null +++ b/core/src/main/java/com/google/adk/reasoning/BaseReasoningBankService.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.reasoning; + +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; + +/** + * Base contract for reasoning bank services. + * + *

The service provides functionalities to store and retrieve reasoning strategies that can be + * used to augment LLM prompts with relevant problem-solving approaches. + * + *

Based on the ReasoningBank paper (arXiv:2509.25140). + */ +public interface BaseReasoningBankService { + + /** + * Stores a reasoning strategy in the bank. + * + * @param appName The name of the application. + * @param strategy The strategy to store. + * @return A Completable that completes when the strategy is stored. + */ + Completable storeStrategy(String appName, ReasoningStrategy strategy); + + /** + * Stores a reasoning trace for later distillation into strategies. + * + * @param appName The name of the application. + * @param trace The trace to store. + * @return A Completable that completes when the trace is stored. + */ + Completable storeTrace(String appName, ReasoningTrace trace); + + /** + * Searches for reasoning strategies that match the given query. + * + * @param appName The name of the application. + * @param query The query to search for (typically a task description). + * @return A {@link SearchReasoningResponse} containing matching strategies. + */ + Single searchStrategies(String appName, String query); + + /** + * Searches for reasoning strategies that match the given query with a limit. + * + * @param appName The name of the application. + * @param query The query to search for. + * @param maxResults Maximum number of strategies to return. + * @return A {@link SearchReasoningResponse} containing matching strategies. + */ + Single searchStrategies(String appName, String query, int maxResults); +} diff --git a/core/src/main/java/com/google/adk/reasoning/InMemoryReasoningBankService.java b/core/src/main/java/com/google/adk/reasoning/InMemoryReasoningBankService.java new file mode 100644 index 000000000..8c2fa6157 --- /dev/null +++ b/core/src/main/java/com/google/adk/reasoning/InMemoryReasoningBankService.java @@ -0,0 +1,180 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.reasoning; + +import com.google.common.collect.ImmutableSet; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * An in-memory reasoning bank service for prototyping purposes only. + * + *

Uses keyword matching instead of semantic search. For production use, consider implementing a + * service backed by vector embeddings for semantic similarity matching. + */ +public final class InMemoryReasoningBankService implements BaseReasoningBankService { + + private static final int DEFAULT_MAX_RESULTS = 5; + + // Pattern to extract words for keyword matching. + private static final Pattern WORD_PATTERN = Pattern.compile("[A-Za-z]+"); + + /** Keys are app names, values are lists of strategies. */ + private final Map> strategies; + + /** Keys are app names, values are lists of traces. */ + private final Map> traces; + + public InMemoryReasoningBankService() { + this.strategies = new ConcurrentHashMap<>(); + this.traces = new ConcurrentHashMap<>(); + } + + @Override + public Completable storeStrategy(String appName, ReasoningStrategy strategy) { + return Completable.fromAction( + () -> { + List appStrategies = + strategies.computeIfAbsent( + appName, k -> Collections.synchronizedList(new ArrayList<>())); + appStrategies.add(strategy); + }); + } + + @Override + public Completable storeTrace(String appName, ReasoningTrace trace) { + return Completable.fromAction( + () -> { + List appTraces = + traces.computeIfAbsent(appName, k -> Collections.synchronizedList(new ArrayList<>())); + appTraces.add(trace); + }); + } + + @Override + public Single searchStrategies(String appName, String query) { + return searchStrategies(appName, query, DEFAULT_MAX_RESULTS); + } + + @Override + public Single searchStrategies( + String appName, String query, int maxResults) { + return Single.fromCallable( + () -> { + if (!strategies.containsKey(appName)) { + return SearchReasoningResponse.builder().build(); + } + + List appStrategies = strategies.get(appName); + ImmutableSet queryWords = extractWords(query); + + if (queryWords.isEmpty()) { + return SearchReasoningResponse.builder().build(); + } + + List scoredStrategies = new ArrayList<>(); + + for (ReasoningStrategy strategy : appStrategies) { + int score = calculateMatchScore(strategy, queryWords); + if (score > 0) { + scoredStrategies.add(new ScoredStrategy(strategy, score)); + } + } + + // Sort by score descending + scoredStrategies.sort((a, b) -> Integer.compare(b.score, a.score)); + + // Take top results + List matchingStrategies = + scoredStrategies.stream() + .map(scoredStrategy -> scoredStrategy.strategy) + .limit(maxResults) + .collect(Collectors.toList()); + + return SearchReasoningResponse.builder().setStrategies(matchingStrategies).build(); + }); + } + + private int calculateMatchScore(ReasoningStrategy strategy, Set queryWords) { + int score = 0; + + // Check problem pattern + Set patternWords = extractWords(strategy.problemPattern()); + score += countOverlap(queryWords, patternWords) * 3; // Weight pattern matches higher + + // Check name + Set nameWords = extractWords(strategy.name()); + score += countOverlap(queryWords, nameWords) * 2; + + // Check tags + for (String tag : strategy.tags()) { + Set tagWords = extractWords(tag); + score += countOverlap(queryWords, tagWords); + } + + // Check steps (lower weight) + for (String step : strategy.steps()) { + Set stepWords = extractWords(step); + if (!Collections.disjoint(queryWords, stepWords)) { + score += 1; + } + } + + return score; + } + + private int countOverlap(Set set1, Set set2) { + Set intersection = new HashSet<>(set1); + intersection.retainAll(set2); + return intersection.size(); + } + + private ImmutableSet extractWords(String text) { + if (text == null || text.isEmpty()) { + return ImmutableSet.of(); + } + + Set words = new HashSet<>(); + Matcher matcher = WORD_PATTERN.matcher(text); + while (matcher.find()) { + words.add(matcher.group().toLowerCase(Locale.ROOT)); + } + return ImmutableSet.copyOf(words); + } + + /** Helper class for scoring strategies during search. */ + private static class ScoredStrategy { + final ReasoningStrategy strategy; + final int score; + + ScoredStrategy(ReasoningStrategy strategy, int score) { + this.strategy = strategy; + this.score = score; + } + } +} diff --git a/core/src/main/java/com/google/adk/reasoning/ReasoningStrategy.java b/core/src/main/java/com/google/adk/reasoning/ReasoningStrategy.java new file mode 100644 index 000000000..a16e66e86 --- /dev/null +++ b/core/src/main/java/com/google/adk/reasoning/ReasoningStrategy.java @@ -0,0 +1,130 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.reasoning; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import java.time.Instant; +import javax.annotation.Nullable; + +/** + * Represents a distilled reasoning strategy that can be reused across tasks. + * + *

A reasoning strategy captures a generalized approach to solving a class of problems, distilled + * from one or more successful task executions. Strategies include the problem pattern they apply + * to, the reasoning steps to follow, and optional metadata for retrieval and organization. + * + *

Based on the ReasoningBank paper (arXiv:2509.25140). + */ +@AutoValue +@JsonDeserialize(builder = ReasoningStrategy.Builder.class) +public abstract class ReasoningStrategy { + + /** Returns the unique identifier for this strategy. */ + @JsonProperty("id") + public abstract String id(); + + /** Returns the name or title of this strategy. */ + @JsonProperty("name") + public abstract String name(); + + /** + * Returns the description of the problem pattern this strategy applies to. + * + *

This is used for matching strategies to new tasks. + */ + @JsonProperty("problemPattern") + public abstract String problemPattern(); + + /** + * Returns the ordered list of reasoning steps that comprise this strategy. + * + *

Each step describes a phase of the reasoning process. + */ + @JsonProperty("steps") + public abstract ImmutableList steps(); + + /** Returns optional tags for categorization and retrieval. */ + @JsonProperty("tags") + public abstract ImmutableList tags(); + + /** Returns the timestamp when this strategy was created. */ + @Nullable + @JsonProperty("createdAt") + public abstract String createdAt(); + + /** Returns a new builder for creating a {@link ReasoningStrategy}. */ + public static Builder builder() { + return new AutoValue_ReasoningStrategy.Builder().tags(ImmutableList.of()); + } + + /** + * Creates a new builder with a copy of this strategy's values. + * + * @return a new {@link Builder} instance. + */ + public abstract Builder toBuilder(); + + /** Builder for {@link ReasoningStrategy}. */ + @AutoValue.Builder + public abstract static class Builder { + + @JsonCreator + static Builder create() { + return new AutoValue_ReasoningStrategy.Builder().tags(ImmutableList.of()); + } + + /** Sets the unique identifier for this strategy. */ + @JsonProperty("id") + public abstract Builder id(String id); + + /** Sets the name of this strategy. */ + @JsonProperty("name") + public abstract Builder name(String name); + + /** Sets the problem pattern description. */ + @JsonProperty("problemPattern") + public abstract Builder problemPattern(String problemPattern); + + /** Sets the ordered list of reasoning steps. */ + @JsonProperty("steps") + public abstract Builder steps(ImmutableList steps); + + /** Sets the tags for categorization. */ + @JsonProperty("tags") + public abstract Builder tags(ImmutableList tags); + + /** Sets the creation timestamp as an ISO 8601 string. */ + @JsonProperty("createdAt") + public abstract Builder createdAt(@Nullable String createdAt); + + /** + * Convenience method to set the creation timestamp from an {@link Instant}. + * + * @param instant The timestamp as an Instant object. + */ + public Builder createdAt(Instant instant) { + return createdAt(instant.toString()); + } + + /** Builds the immutable {@link ReasoningStrategy} object. */ + public abstract ReasoningStrategy build(); + } +} diff --git a/core/src/main/java/com/google/adk/reasoning/ReasoningTrace.java b/core/src/main/java/com/google/adk/reasoning/ReasoningTrace.java new file mode 100644 index 000000000..3825f7d34 --- /dev/null +++ b/core/src/main/java/com/google/adk/reasoning/ReasoningTrace.java @@ -0,0 +1,139 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.reasoning; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import java.time.Instant; +import javax.annotation.Nullable; + +/** + * Represents a raw reasoning trace captured from a task execution. + * + *

A reasoning trace captures the input, output, and intermediate reasoning steps from a + * successful task execution. Traces can be distilled into reusable {@link ReasoningStrategy} + * objects. + * + *

Based on the ReasoningBank paper (arXiv:2509.25140). + */ +@AutoValue +@JsonDeserialize(builder = ReasoningTrace.Builder.class) +public abstract class ReasoningTrace { + + /** Returns the unique identifier for this trace. */ + @JsonProperty("id") + public abstract String id(); + + /** Returns the original task or prompt that was executed. */ + @JsonProperty("task") + public abstract String task(); + + /** Returns the final output or response from the task execution. */ + @JsonProperty("output") + public abstract String output(); + + /** + * Returns the intermediate reasoning steps captured during execution. + * + *

These are the raw chain-of-thought steps before distillation. + */ + @JsonProperty("reasoningSteps") + public abstract ImmutableList reasoningSteps(); + + /** Returns whether the task execution was successful. */ + @JsonProperty("successful") + public abstract boolean successful(); + + /** Returns the timestamp when this trace was captured. */ + @Nullable + @JsonProperty("capturedAt") + public abstract String capturedAt(); + + /** Returns optional metadata about the execution context. */ + @Nullable + @JsonProperty("metadata") + public abstract String metadata(); + + /** Returns a new builder for creating a {@link ReasoningTrace}. */ + public static Builder builder() { + return new AutoValue_ReasoningTrace.Builder() + .reasoningSteps(ImmutableList.of()) + .successful(true); + } + + /** + * Creates a new builder with a copy of this trace's values. + * + * @return a new {@link Builder} instance. + */ + public abstract Builder toBuilder(); + + /** Builder for {@link ReasoningTrace}. */ + @AutoValue.Builder + public abstract static class Builder { + + @JsonCreator + static Builder create() { + return new AutoValue_ReasoningTrace.Builder() + .reasoningSteps(ImmutableList.of()) + .successful(true); + } + + /** Sets the unique identifier for this trace. */ + @JsonProperty("id") + public abstract Builder id(String id); + + /** Sets the original task or prompt. */ + @JsonProperty("task") + public abstract Builder task(String task); + + /** Sets the final output from the task execution. */ + @JsonProperty("output") + public abstract Builder output(String output); + + /** Sets the intermediate reasoning steps. */ + @JsonProperty("reasoningSteps") + public abstract Builder reasoningSteps(ImmutableList reasoningSteps); + + /** Sets whether the task execution was successful. */ + @JsonProperty("successful") + public abstract Builder successful(boolean successful); + + /** Sets the capture timestamp as an ISO 8601 string. */ + @JsonProperty("capturedAt") + public abstract Builder capturedAt(@Nullable String capturedAt); + + /** + * Convenience method to set the capture timestamp from an {@link Instant}. + * + * @param instant The timestamp as an Instant object. + */ + public Builder capturedAt(Instant instant) { + return capturedAt(instant.toString()); + } + + /** Sets optional metadata about the execution context. */ + @JsonProperty("metadata") + public abstract Builder metadata(@Nullable String metadata); + + /** Builds the immutable {@link ReasoningTrace} object. */ + public abstract ReasoningTrace build(); + } +} diff --git a/core/src/main/java/com/google/adk/reasoning/SearchReasoningResponse.java b/core/src/main/java/com/google/adk/reasoning/SearchReasoningResponse.java new file mode 100644 index 000000000..07d284e19 --- /dev/null +++ b/core/src/main/java/com/google/adk/reasoning/SearchReasoningResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.reasoning; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import java.util.List; + +/** Represents the response from a reasoning strategy search. */ +@AutoValue +public abstract class SearchReasoningResponse { + + /** Returns a list of reasoning strategies that match the search query. */ + public abstract ImmutableList strategies(); + + /** Creates a new builder for {@link SearchReasoningResponse}. */ + public static Builder builder() { + return new AutoValue_SearchReasoningResponse.Builder().setStrategies(ImmutableList.of()); + } + + /** Builder for {@link SearchReasoningResponse}. */ + @AutoValue.Builder + public abstract static class Builder { + + abstract Builder setStrategies(ImmutableList strategies); + + /** Sets the list of reasoning strategies using a list. */ + public Builder setStrategies(List strategies) { + return setStrategies(ImmutableList.copyOf(strategies)); + } + + /** Builds the immutable {@link SearchReasoningResponse} object. */ + public abstract SearchReasoningResponse build(); + } +} diff --git a/core/src/main/java/com/google/adk/tools/LoadReasoningStrategyResponse.java b/core/src/main/java/com/google/adk/tools/LoadReasoningStrategyResponse.java new file mode 100644 index 000000000..2d1cb4ed4 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/LoadReasoningStrategyResponse.java @@ -0,0 +1,25 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.adk.reasoning.ReasoningStrategy; +import java.util.List; + +/** The response from a load reasoning strategy tool invocation. */ +public record LoadReasoningStrategyResponse( + @JsonProperty("strategies") List strategies) {} diff --git a/core/src/main/java/com/google/adk/tools/LoadReasoningStrategyTool.java b/core/src/main/java/com/google/adk/tools/LoadReasoningStrategyTool.java new file mode 100644 index 000000000..c65e5ef2d --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/LoadReasoningStrategyTool.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.google.adk.models.LlmRequest; +import com.google.common.collect.ImmutableList; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.lang.reflect.Method; + +/** + * A tool that loads reasoning strategies for the current task. + * + *

This tool allows agents to retrieve relevant reasoning strategies from the ReasoningBank based + * on a query describing the current task. The retrieved strategies provide structured + * problem-solving approaches that can guide the agent's reasoning. + * + *

Based on the ReasoningBank paper (arXiv:2509.25140). + */ +public class LoadReasoningStrategyTool extends FunctionTool { + + private static Method getLoadReasoningStrategyMethod() { + try { + return LoadReasoningStrategyTool.class.getMethod( + "loadReasoningStrategy", String.class, ToolContext.class); + } catch (NoSuchMethodException e) { + throw new IllegalStateException("Failed to find loadReasoningStrategy method.", e); + } + } + + public LoadReasoningStrategyTool() { + super( + /* instance= */ null, + getLoadReasoningStrategyMethod(), + /* isLongRunning= */ false, + /* requireConfirmation= */ false); + } + + /** + * Loads reasoning strategies that match the given query. + * + * @param query A description of the task or problem to find strategies for. + * @param toolContext The tool context providing access to the reasoning bank service. + * @return A response containing matching reasoning strategies. + */ + public static Single loadReasoningStrategy( + @Annotations.Schema(name = "query", description = "A description of the task or problem") + String query, + ToolContext toolContext) { + return toolContext + .searchReasoningStrategies(query) + .map(response -> new LoadReasoningStrategyResponse(response.strategies())); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return super.processLlmRequest(llmRequestBuilder, toolContext) + .doOnComplete( + () -> + llmRequestBuilder.appendInstructions( + ImmutableList.of( +""" +You have access to a reasoning bank containing proven problem-solving strategies. +When facing a complex task, you can call loadReasoningStrategy with a description +of your task to retrieve relevant reasoning approaches. Each strategy includes +problem patterns it addresses and ordered reasoning steps to follow. +"""))); + } +} diff --git a/core/src/main/java/com/google/adk/tools/ToolContext.java b/core/src/main/java/com/google/adk/tools/ToolContext.java index 50a27bd85..fa6b49095 100644 --- a/core/src/main/java/com/google/adk/tools/ToolContext.java +++ b/core/src/main/java/com/google/adk/tools/ToolContext.java @@ -20,6 +20,7 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.events.EventActions; import com.google.adk.memory.SearchMemoryResponse; +import com.google.adk.reasoning.SearchReasoningResponse; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Single; import java.util.Optional; @@ -116,6 +117,16 @@ public Single searchMemory(String query) { invocationContext.session().appName(), invocationContext.session().userId(), query); } + /** Searches for reasoning strategies matching the given query. */ + public Single searchReasoningStrategies(String query) { + if (invocationContext.reasoningBankService() == null) { + throw new IllegalStateException("Reasoning bank service is not initialized."); + } + return invocationContext + .reasoningBankService() + .searchStrategies(invocationContext.session().appName(), query); + } + public static Builder builder(InvocationContext invocationContext) { return new Builder(invocationContext); } diff --git a/core/src/test/java/com/google/adk/reasoning/InMemoryReasoningBankServiceTest.java b/core/src/test/java/com/google/adk/reasoning/InMemoryReasoningBankServiceTest.java new file mode 100644 index 000000000..9cf784c84 --- /dev/null +++ b/core/src/test/java/com/google/adk/reasoning/InMemoryReasoningBankServiceTest.java @@ -0,0 +1,251 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.reasoning; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link InMemoryReasoningBankService}. */ +@RunWith(JUnit4.class) +public final class InMemoryReasoningBankServiceTest { + + private static final String APP_NAME = "test-app"; + + private InMemoryReasoningBankService service; + + @Before + public void setUp() { + service = new InMemoryReasoningBankService(); + } + + @Test + public void searchStrategies_emptyBank_returnsEmpty() { + SearchReasoningResponse response = + service.searchStrategies(APP_NAME, "math problem").blockingGet(); + + assertThat(response.strategies()).isEmpty(); + } + + @Test + public void storeAndSearch_findsMatchingStrategy() { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Math Problem Solving") + .problemPattern("Mathematical calculations involving algebra") + .steps(ImmutableList.of("Identify unknowns", "Set up equations", "Solve")) + .tags(ImmutableList.of("math", "algebra")) + .build(); + + service.storeStrategy(APP_NAME, strategy).blockingAwait(); + + SearchReasoningResponse response = + service.searchStrategies(APP_NAME, "algebra problem").blockingGet(); + + assertThat(response.strategies()).hasSize(1); + assertThat(response.strategies().get(0).id()).isEqualTo("strategy-1"); + } + + @Test + public void searchStrategies_noMatch_returnsEmpty() { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Math Problem Solving") + .problemPattern("Mathematical calculations") + .steps(ImmutableList.of("Step 1")) + .build(); + + service.storeStrategy(APP_NAME, strategy).blockingAwait(); + + SearchReasoningResponse response = + service.searchStrategies(APP_NAME, "biology chemistry").blockingGet(); + + assertThat(response.strategies()).isEmpty(); + } + + @Test + public void searchStrategies_matchesByName() { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Debugging Code") + .problemPattern("Test pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + + service.storeStrategy(APP_NAME, strategy).blockingAwait(); + + SearchReasoningResponse response = + service.searchStrategies(APP_NAME, "code debugging").blockingGet(); + + assertThat(response.strategies()).hasSize(1); + } + + @Test + public void searchStrategies_matchesByTags() { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Test Strategy") + .problemPattern("Test pattern") + .steps(ImmutableList.of("Step 1")) + .tags(ImmutableList.of("python", "programming")) + .build(); + + service.storeStrategy(APP_NAME, strategy).blockingAwait(); + + SearchReasoningResponse response = service.searchStrategies(APP_NAME, "python").blockingGet(); + + assertThat(response.strategies()).hasSize(1); + } + + @Test + public void searchStrategies_rankedByRelevance() { + // Strategy with pattern match (highest weight) + ReasoningStrategy patternMatch = + ReasoningStrategy.builder() + .id("pattern-match") + .name("Other Name") + .problemPattern("algorithm optimization problems") + .steps(ImmutableList.of("Step 1")) + .build(); + + // Strategy with name match (medium weight) + ReasoningStrategy nameMatch = + ReasoningStrategy.builder() + .id("name-match") + .name("Algorithm Design") + .problemPattern("Other pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + + service.storeStrategy(APP_NAME, nameMatch).blockingAwait(); + service.storeStrategy(APP_NAME, patternMatch).blockingAwait(); + + SearchReasoningResponse response = + service.searchStrategies(APP_NAME, "algorithm").blockingGet(); + + assertThat(response.strategies()).hasSize(2); + // Pattern match should rank higher than name match + assertThat(response.strategies().get(0).id()).isEqualTo("pattern-match"); + } + + @Test + public void searchStrategies_respectsMaxResults() { + for (int i = 0; i < 10; i++) { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-" + i) + .name("Test Strategy " + i) + .problemPattern("Common problem pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + service.storeStrategy(APP_NAME, strategy).blockingAwait(); + } + + SearchReasoningResponse response = + service.searchStrategies(APP_NAME, "problem pattern", 3).blockingGet(); + + assertThat(response.strategies()).hasSize(3); + } + + @Test + public void searchStrategies_differentApps_isolated() { + ReasoningStrategy strategy1 = + ReasoningStrategy.builder() + .id("app1-strategy") + .name("Test Strategy") + .problemPattern("Test pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + + ReasoningStrategy strategy2 = + ReasoningStrategy.builder() + .id("app2-strategy") + .name("Test Strategy") + .problemPattern("Test pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + + service.storeStrategy("app1", strategy1).blockingAwait(); + service.storeStrategy("app2", strategy2).blockingAwait(); + + SearchReasoningResponse response1 = service.searchStrategies("app1", "test").blockingGet(); + SearchReasoningResponse response2 = service.searchStrategies("app2", "test").blockingGet(); + + assertThat(response1.strategies()).hasSize(1); + assertThat(response1.strategies().get(0).id()).isEqualTo("app1-strategy"); + + assertThat(response2.strategies()).hasSize(1); + assertThat(response2.strategies().get(0).id()).isEqualTo("app2-strategy"); + } + + @Test + public void storeTrace_tracesAreStored() { + ReasoningTrace trace = + ReasoningTrace.builder() + .id("trace-1") + .task("Test task") + .output("Test output") + .reasoningSteps(ImmutableList.of("Step 1")) + .successful(true) + .build(); + + // Should complete without error + service.storeTrace(APP_NAME, trace).blockingAwait(); + } + + @Test + public void searchStrategies_emptyQuery_returnsEmpty() { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Test Strategy") + .problemPattern("Test pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + + service.storeStrategy(APP_NAME, strategy).blockingAwait(); + + SearchReasoningResponse response = service.searchStrategies(APP_NAME, "").blockingGet(); + + assertThat(response.strategies()).isEmpty(); + } + + @Test + public void searchStrategies_caseInsensitive() { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Test Strategy") + .problemPattern("UPPERCASE pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + + service.storeStrategy(APP_NAME, strategy).blockingAwait(); + + SearchReasoningResponse response = + service.searchStrategies(APP_NAME, "uppercase").blockingGet(); + + assertThat(response.strategies()).hasSize(1); + } +} diff --git a/core/src/test/java/com/google/adk/reasoning/ReasoningStrategyTest.java b/core/src/test/java/com/google/adk/reasoning/ReasoningStrategyTest.java new file mode 100644 index 000000000..e8a4db006 --- /dev/null +++ b/core/src/test/java/com/google/adk/reasoning/ReasoningStrategyTest.java @@ -0,0 +1,99 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.reasoning; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import java.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ReasoningStrategy}. */ +@RunWith(JUnit4.class) +public final class ReasoningStrategyTest { + + @Test + public void builder_createsValidStrategy() { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Math Problem Solving") + .problemPattern("Mathematical word problems involving rates") + .steps( + ImmutableList.of( + "Identify the known quantities", + "Identify what needs to be found", + "Set up equations", + "Solve and verify")) + .tags(ImmutableList.of("math", "rates")) + .createdAt("2025-01-05T10:00:00Z") + .build(); + + assertThat(strategy.id()).isEqualTo("strategy-1"); + assertThat(strategy.name()).isEqualTo("Math Problem Solving"); + assertThat(strategy.problemPattern()).isEqualTo("Mathematical word problems involving rates"); + assertThat(strategy.steps()).hasSize(4); + assertThat(strategy.tags()).containsExactly("math", "rates"); + assertThat(strategy.createdAt()).isEqualTo("2025-01-05T10:00:00Z"); + } + + @Test + public void builder_defaultTagsIsEmpty() { + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Test Strategy") + .problemPattern("Test pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + + assertThat(strategy.tags()).isEmpty(); + } + + @Test + public void builder_createdAtWithInstant() { + Instant now = Instant.parse("2025-01-05T12:00:00Z"); + ReasoningStrategy strategy = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Test Strategy") + .problemPattern("Test pattern") + .steps(ImmutableList.of("Step 1")) + .createdAt(now) + .build(); + + assertThat(strategy.createdAt()).isEqualTo("2025-01-05T12:00:00Z"); + } + + @Test + public void toBuilder_createsCopy() { + ReasoningStrategy original = + ReasoningStrategy.builder() + .id("strategy-1") + .name("Original") + .problemPattern("Test pattern") + .steps(ImmutableList.of("Step 1")) + .build(); + + ReasoningStrategy modified = original.toBuilder().name("Modified").build(); + + assertThat(original.name()).isEqualTo("Original"); + assertThat(modified.name()).isEqualTo("Modified"); + assertThat(modified.id()).isEqualTo(original.id()); + } +} diff --git a/core/src/test/java/com/google/adk/reasoning/ReasoningTraceTest.java b/core/src/test/java/com/google/adk/reasoning/ReasoningTraceTest.java new file mode 100644 index 000000000..ee45e297f --- /dev/null +++ b/core/src/test/java/com/google/adk/reasoning/ReasoningTraceTest.java @@ -0,0 +1,104 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.reasoning; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import java.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ReasoningTrace}. */ +@RunWith(JUnit4.class) +public final class ReasoningTraceTest { + + @Test + public void builder_createsValidTrace() { + ReasoningTrace trace = + ReasoningTrace.builder() + .id("trace-1") + .task("Calculate the area of a circle with radius 5") + .output("The area is 78.54 square units") + .reasoningSteps( + ImmutableList.of( + "Recall the formula: A = πr²", + "Substitute r = 5", + "Calculate: A = π × 25 = 78.54")) + .successful(true) + .capturedAt("2025-01-05T10:00:00Z") + .metadata("source=test") + .build(); + + assertThat(trace.id()).isEqualTo("trace-1"); + assertThat(trace.task()).isEqualTo("Calculate the area of a circle with radius 5"); + assertThat(trace.output()).isEqualTo("The area is 78.54 square units"); + assertThat(trace.reasoningSteps()).hasSize(3); + assertThat(trace.successful()).isTrue(); + assertThat(trace.capturedAt()).isEqualTo("2025-01-05T10:00:00Z"); + assertThat(trace.metadata()).isEqualTo("source=test"); + } + + @Test + public void builder_defaultsToSuccessful() { + ReasoningTrace trace = + ReasoningTrace.builder().id("trace-1").task("Test task").output("Test output").build(); + + assertThat(trace.successful()).isTrue(); + } + + @Test + public void builder_defaultReasoningStepsIsEmpty() { + ReasoningTrace trace = + ReasoningTrace.builder().id("trace-1").task("Test task").output("Test output").build(); + + assertThat(trace.reasoningSteps()).isEmpty(); + } + + @Test + public void builder_capturedAtWithInstant() { + Instant now = Instant.parse("2025-01-05T12:00:00Z"); + ReasoningTrace trace = + ReasoningTrace.builder() + .id("trace-1") + .task("Test task") + .output("Test output") + .capturedAt(now) + .build(); + + assertThat(trace.capturedAt()).isEqualTo("2025-01-05T12:00:00Z"); + } + + @Test + public void toBuilder_createsCopy() { + ReasoningTrace original = + ReasoningTrace.builder() + .id("trace-1") + .task("Original task") + .output("Original output") + .successful(true) + .build(); + + ReasoningTrace modified = original.toBuilder().task("Modified task").successful(false).build(); + + assertThat(original.task()).isEqualTo("Original task"); + assertThat(original.successful()).isTrue(); + assertThat(modified.task()).isEqualTo("Modified task"); + assertThat(modified.successful()).isFalse(); + assertThat(modified.id()).isEqualTo(original.id()); + } +}