diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index f472cba6..e9a0494a 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -16,6 +16,8 @@ package com.google.adk.agents; +import static com.google.common.collect.ImmutableList.toImmutableList; + import com.google.adk.Telemetry; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; @@ -57,7 +59,8 @@ public abstract class BaseAgent { private final List subAgents; - protected final CallbackPlugin callbackPlugin; + private final Optional> beforeAgentCallback; + private final Optional> afterAgentCallback; /** * Creates a new BaseAgent. @@ -76,32 +79,12 @@ public BaseAgent( List subAgents, @Nullable List beforeAgentCallback, @Nullable List afterAgentCallback) { - this( - name, - description, - subAgents, - createCallbackPlugin(beforeAgentCallback, afterAgentCallback)); - } - - /** - * Creates a new BaseAgent. - * - * @param name Unique agent name. Cannot be "user" (reserved). - * @param description Agent purpose. - * @param subAgents Agents managed by this agent. - * @param callbackPlugin The callback plugin for this agent. - */ - protected BaseAgent( - String name, - String description, - List subAgents, - CallbackPlugin callbackPlugin) { this.name = name; this.description = description; this.parentAgent = null; this.subAgents = subAgents != null ? subAgents : ImmutableList.of(); - this.callbackPlugin = - callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin; + this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback); + this.afterAgentCallback = Optional.ofNullable(afterAgentCallback); // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -109,18 +92,6 @@ protected BaseAgent( } } - /** Creates a {@link CallbackPlugin} from lists of before and after agent callbacks. */ - private static CallbackPlugin createCallbackPlugin( - @Nullable List beforeAgentCallbacks, - @Nullable List afterAgentCallbacks) { - CallbackPlugin.Builder builder = CallbackPlugin.builder(); - Stream.ofNullable(beforeAgentCallbacks).flatMap(List::stream).forEach(builder::addCallback); - Optional.ofNullable(afterAgentCallbacks).stream() - .flatMap(List::stream) - .forEach(builder::addCallback); - return builder.build(); - } - /** * Gets the agent's unique name. * @@ -201,15 +172,11 @@ public List subAgents() { } public Optional> beforeAgentCallback() { - return Optional.of(callbackPlugin.getBeforeAgentCallback()); + return beforeAgentCallback; } public Optional> afterAgentCallback() { - return Optional.of(callbackPlugin.getAfterAgentCallback()); - } - - public Plugin getPlugin() { - return callbackPlugin; + return afterAgentCallback; } /** @@ -252,11 +219,11 @@ public Flowable runAsync(InvocationContext parentContext) { spanContext, span, () -> - processAgentCallbackResult( - ctx -> invocationContext.combinedPlugin().beforeAgentCallback(this, ctx), + callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), + beforeAgentCallback.orElse(ImmutableList.of())), invocationContext) - .map(Optional::of) - .switchIfEmpty(Single.just(Optional.empty())) .flatMapPublisher( beforeEventOpt -> { if (invocationContext.endInvocation()) { @@ -269,14 +236,11 @@ public Flowable runAsync(InvocationContext parentContext) { Flowable afterEvents = Flowable.defer( () -> - processAgentCallbackResult( - ctx -> - invocationContext - .combinedPlugin() - .afterAgentCallback(this, ctx), + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), + afterAgentCallback.orElse(ImmutableList.of())), invocationContext) - .map(Optional::of) - .switchIfEmpty(Single.just(Optional.empty())) .flatMapPublisher(Flowable::fromOptional)); return Flowable.concat(beforeEvents, mainEvents, afterEvents); @@ -285,32 +249,76 @@ public Flowable runAsync(InvocationContext parentContext) { } /** - * Processes the result of an agent callback, creating an {@link Event} if necessary. + * Converts before-agent callbacks to functions. + * + * @param callbacks Before-agent callbacks. + * @return callback functions. + */ + private ImmutableList>> beforeCallbacksToFunctions( + Plugin pluginManager, List callbacks) { + return Stream.concat( + Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)), + callbacks.stream() + .map(callback -> (Function>) callback::call)) + .collect(toImmutableList()); + } + + /** + * Converts after-agent callbacks to functions. + * + * @param callbacks After-agent callbacks. + * @return callback functions. + */ + private ImmutableList>> afterCallbacksToFunctions( + Plugin pluginManager, List callbacks) { + return Stream.concat( + Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)), + callbacks.stream() + .map(callback -> (Function>) callback::call)) + .collect(toImmutableList()); + } + + /** + * Calls agent callbacks and returns the first produced event, if any. * - * @param agentCallback The callback function. - * @param invocationContext The current invocation context. - * @return A {@link Maybe} emitting an {@link Event} if one is produced, or empty otherwise. + * @param agentCallbacks Callback functions. + * @param invocationContext Current invocation context. + * @return single emitting first event, or empty if none. */ - private Maybe processAgentCallbackResult( - Function> agentCallback, + private Single> callCallback( + List>> agentCallbacks, InvocationContext invocationContext) { - var callbackContext = new CallbackContext(invocationContext, /* eventActions= */ null); - return agentCallback - .apply(callbackContext) - .map( - content -> { - invocationContext.setEndInvocation(true); - return Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch()) - .actions(callbackContext.eventActions()) - .content(content) - .build(); + if (agentCallbacks == null || agentCallbacks.isEmpty()) { + return Single.just(Optional.empty()); + } + + CallbackContext callbackContext = + new CallbackContext(invocationContext, /* eventActions= */ null); + + return Flowable.fromIterable(agentCallbacks) + .concatMap( + callback -> { + Maybe maybeContent = callback.apply(callbackContext); + + return maybeContent + .map( + content -> { + invocationContext.setEndInvocation(true); + return Optional.of( + Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch()) + .actions(callbackContext.eventActions()) + .content(content) + .build()); + }) + .toFlowable(); }) + .firstElement() .switchIfEmpty( - Maybe.defer( + Single.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -321,9 +329,9 @@ private Maybe processAgentCallbackResult( .branch(invocationContext.branch()) .actions(callbackContext.eventActions()); - return Maybe.just(eventBuilder.build()); + return Single.just(Optional.of(eventBuilder.build())); } else { - return Maybe.empty(); + return Single.just(Optional.empty()); } })); } @@ -391,11 +399,8 @@ public abstract static class Builder> { protected String name; protected String description; protected ImmutableList subAgents; - protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder(); - - protected CallbackPlugin.Builder callbackPluginBuilder() { - return callbackPluginBuilder; - } + protected ImmutableList beforeAgentCallback; + protected ImmutableList afterAgentCallback; /** This is a safe cast to the concrete builder type. */ @SuppressWarnings("unchecked") @@ -429,25 +434,25 @@ public B subAgents(BaseAgent... subAgents) { @CanIgnoreReturnValue public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) { - callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback); + this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B beforeAgentCallback(List beforeAgentCallback) { - beforeAgentCallback.forEach(callbackPluginBuilder::addCallback); + this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(AfterAgentCallback afterAgentCallback) { - callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback); + this.afterAgentCallback = ImmutableList.of(afterAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(List afterAgentCallback) { - afterAgentCallback.forEach(callbackPluginBuilder::addCallback); + this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback); return self(); } diff --git a/core/src/main/java/com/google/adk/agents/CallbackPlugin.java b/core/src/main/java/com/google/adk/agents/CallbackPlugin.java deleted file mode 100644 index 10e260fd..00000000 --- a/core/src/main/java/com/google/adk/agents/CallbackPlugin.java +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.agents; - -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.AfterAgentCallbackBase; -import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; -import com.google.adk.agents.Callbacks.AfterModelCallback; -import com.google.adk.agents.Callbacks.AfterModelCallbackBase; -import com.google.adk.agents.Callbacks.AfterModelCallbackSync; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.AfterToolCallbackBase; -import com.google.adk.agents.Callbacks.AfterToolCallbackSync; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackBase; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; -import com.google.adk.agents.Callbacks.BeforeModelCallback; -import com.google.adk.agents.Callbacks.BeforeModelCallbackBase; -import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; -import com.google.adk.agents.Callbacks.BeforeToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallbackBase; -import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import com.google.adk.plugins.BasePlugin; -import com.google.adk.plugins.PluginManager; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ListMultimap; -import com.google.errorprone.annotations.CanIgnoreReturnValue; -import com.google.genai.types.Content; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A plugin that wraps callbacks and exposes them as a plugin. */ -public class CallbackPlugin extends PluginManager { - - private static final Logger logger = LoggerFactory.getLogger(CallbackPlugin.class); - - private final ImmutableListMultimap, Object> callbacks; - - private CallbackPlugin( - ImmutableList plugins, - ImmutableListMultimap, Object> callbacks) { - super(plugins); - this.callbacks = callbacks; - } - - @Override - public String getName() { - return "CallbackPlugin"; - } - - @SuppressWarnings("unchecked") // The builder ensures that the type is correct. - private ImmutableList getCallbacks(Class type) { - return (ImmutableList) callbacks.get(type); - } - - public ImmutableList getBeforeAgentCallback() { - return getCallbacks(Callbacks.BeforeAgentCallback.class); - } - - public ImmutableList getAfterAgentCallback() { - return getCallbacks(Callbacks.AfterAgentCallback.class); - } - - public ImmutableList getBeforeModelCallback() { - return getCallbacks(Callbacks.BeforeModelCallback.class); - } - - public ImmutableList getAfterModelCallback() { - return getCallbacks(Callbacks.AfterModelCallback.class); - } - - public ImmutableList getBeforeToolCallback() { - return getCallbacks(Callbacks.BeforeToolCallback.class); - } - - public ImmutableList getAfterToolCallback() { - return getCallbacks(Callbacks.AfterToolCallback.class); - } - - public static Builder builder() { - return new Builder(); - } - - /** Builder for {@link CallbackPlugin}. */ - public static class Builder { - // Ensures a unique name for each callback. - private static final AtomicInteger callbackId = new AtomicInteger(0); - - private ImmutableList.Builder plugins = ImmutableList.builder(); - private final ListMultimap, Object> callbacks = ArrayListMultimap.create(); - - Builder() {} - - @CanIgnoreReturnValue - public Builder addBeforeAgentCallback(Callbacks.BeforeAgentCallback callback) { - callbacks.put(Callbacks.BeforeAgentCallback.class, callback); - plugins.add( - new BasePlugin("BeforeAgentCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe beforeAgentCallback( - BaseAgent agent, CallbackContext callbackContext) { - return callback.call(callbackContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeAgentCallbackSync(Callbacks.BeforeAgentCallbackSync callback) { - return addBeforeAgentCallback( - callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); - } - - @CanIgnoreReturnValue - public Builder addAfterAgentCallback(Callbacks.AfterAgentCallback callback) { - callbacks.put(Callbacks.AfterAgentCallback.class, callback); - plugins.add( - new BasePlugin("AfterAgentCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe afterAgentCallback( - BaseAgent agent, CallbackContext callbackContext) { - return callback.call(callbackContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterAgentCallbackSync(Callbacks.AfterAgentCallbackSync callback) { - return addAfterAgentCallback( - callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); - } - - @CanIgnoreReturnValue - public Builder addBeforeModelCallback(Callbacks.BeforeModelCallback callback) { - callbacks.put(Callbacks.BeforeModelCallback.class, callback); - plugins.add( - new BasePlugin("BeforeModelCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest) { - return callback.call(callbackContext, llmRequest); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeModelCallbackSync(Callbacks.BeforeModelCallbackSync callback) { - return addBeforeModelCallback( - (callbackContext, llmRequest) -> - Maybe.fromOptional(callback.call(callbackContext, llmRequest))); - } - - // (b/476510024): Temporary workaround for ces - @CanIgnoreReturnValue - public Builder clearBeforeModelCallbacks() { - callbacks.removeAll(Callbacks.BeforeModelCallback.class); - ImmutableList.Builder updatedPlugins = ImmutableList.builder(); - for (BasePlugin plugin : plugins.build()) { - if (!plugin.getName().startsWith("BeforeModelCallback_")) { - updatedPlugins.add(plugin); - } - } - plugins = updatedPlugins; - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterModelCallback(Callbacks.AfterModelCallback callback) { - callbacks.put(Callbacks.AfterModelCallback.class, callback); - plugins.add( - new BasePlugin("AfterModelCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe afterModelCallback( - CallbackContext callbackContext, LlmResponse llmResponse) { - return callback.call(callbackContext, llmResponse); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterModelCallbackSync(Callbacks.AfterModelCallbackSync callback) { - return addAfterModelCallback( - (callbackContext, llmResponse) -> - Maybe.fromOptional(callback.call(callbackContext, llmResponse))); - } - - @CanIgnoreReturnValue - public Builder addBeforeToolCallback(Callbacks.BeforeToolCallback callback) { - callbacks.put(Callbacks.BeforeToolCallback.class, callback); - plugins.add( - new BasePlugin("BeforeToolCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe> beforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { - return callback.call(toolContext.invocationContext(), tool, toolArgs, toolContext); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addBeforeToolCallbackSync(Callbacks.BeforeToolCallbackSync callback) { - return addBeforeToolCallback( - (invocationContext, tool, toolArgs, toolContext) -> - Maybe.fromOptional(callback.call(invocationContext, tool, toolArgs, toolContext))); - } - - @CanIgnoreReturnValue - public Builder addAfterToolCallback(Callbacks.AfterToolCallback callback) { - callbacks.put(Callbacks.AfterToolCallback.class, callback); - plugins.add( - new BasePlugin("AfterToolCallback_" + callbackId.getAndIncrement()) { - @Override - public Maybe> afterToolCallback( - BaseTool tool, - Map toolArgs, - ToolContext toolContext, - Map result) { - return callback.call( - toolContext.invocationContext(), tool, toolArgs, toolContext, result); - } - }); - return this; - } - - @CanIgnoreReturnValue - public Builder addAfterToolCallbackSync(Callbacks.AfterToolCallbackSync callback) { - return addAfterToolCallback( - (invocationContext, tool, toolArgs, toolContext, result) -> - Maybe.fromOptional( - callback.call(invocationContext, tool, toolArgs, toolContext, result))); - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeAgentCallbackBase callback) { - if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { - addBeforeAgentCallback(beforeAgentCallbackInstance); - } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { - addBeforeAgentCallbackSync(beforeAgentCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeAgentCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterAgentCallbackBase callback) { - if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) { - addAfterAgentCallback(afterAgentCallbackInstance); - } else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) { - addAfterAgentCallbackSync(afterAgentCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterAgentCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeModelCallbackBase callback) { - if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { - addBeforeModelCallback(beforeModelCallbackInstance); - } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { - addBeforeModelCallbackSync(beforeModelCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterModelCallbackBase callback) { - if (callback instanceof AfterModelCallback afterModelCallbackInstance) { - addAfterModelCallback(afterModelCallbackInstance); - } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { - addAfterModelCallbackSync(afterModelCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(BeforeToolCallbackBase callback) { - if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { - addBeforeToolCallback(beforeToolCallbackInstance); - } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { - addBeforeToolCallbackSync(beforeToolCallbackSyncInstance); - } else { - logger.warn( - "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - @CanIgnoreReturnValue - public Builder addCallback(AfterToolCallbackBase callback) { - if (callback instanceof AfterToolCallback afterToolCallbackInstance) { - addAfterToolCallback(afterToolCallbackInstance); - } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { - addAfterToolCallbackSync(afterToolCallbackSyncInstance); - } else { - logger.warn( - "Invalid afterToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - return this; - } - - public CallbackPlugin build() { - return new CallbackPlugin(plugins.build(), ImmutableListMultimap.copyOf(callbacks)); - } - } -} diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 0a8ed416..532bc92f 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -25,7 +25,6 @@ import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.InlineMe; @@ -45,7 +44,6 @@ public class InvocationContext { private final BaseArtifactService artifactService; private final BaseMemoryService memoryService; private final Plugin pluginManager; - private final Plugin combinedPlugin; private final Optional liveRequestQueue; private final Map activeStreamingTools; private final String invocationId; @@ -75,13 +73,6 @@ protected InvocationContext(Builder builder) { this.endInvocation = builder.endInvocation; this.resumabilityConfig = builder.resumabilityConfig; this.invocationCostManager = builder.invocationCostManager; - this.combinedPlugin = - Optional.ofNullable(builder.agent) - .map(BaseAgent::getPlugin) - .map( - agentPlugin -> - (Plugin) new PluginManager(ImmutableList.of(pluginManager, agentPlugin))) - .orElse(pluginManager); } /** @@ -244,14 +235,6 @@ public Plugin pluginManager() { return pluginManager; } - /** - * Returns a {@link Plugin} that combines agent-specific plugins with framework-level plugins, - * allowing tools from both to be invoked. - */ - public Plugin combinedPlugin() { - return combinedPlugin; - } - /** Returns a map of tool call IDs to active streaming tools for the current invocation. */ public Map activeStreamingTools() { return activeStreamingTools; diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index fe6763d1..2cfa26b7 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -56,6 +56,7 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.List; @@ -94,6 +95,10 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; + private final Optional> beforeModelCallback; + private final Optional> afterModelCallback; + private final Optional> beforeToolCallback; + private final Optional> afterToolCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -108,7 +113,8 @@ protected LlmAgent(Builder builder) { builder.name, builder.description, builder.subAgents, - builder.callbackPluginBuilder.build()); + builder.beforeAgentCallback, + builder.afterAgentCallback); this.model = Optional.ofNullable(builder.model); this.instruction = builder.instruction == null ? new Instruction.Static("") : builder.instruction; @@ -122,6 +128,10 @@ protected LlmAgent(Builder builder) { this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; + this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); + this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); + this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); + this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); @@ -163,6 +173,10 @@ public static class Builder extends BaseAgent.Builder { private Integer maxSteps; private Boolean disallowTransferToParent; private Boolean disallowTransferToPeers; + private ImmutableList beforeModelCallback; + private ImmutableList afterModelCallback; + private ImmutableList beforeToolCallback; + private ImmutableList afterToolCallback; private Schema inputSchema; private Schema outputSchema; private Executor executor; @@ -277,92 +291,206 @@ public Builder disallowTransferToPeers(boolean disallowTransferToPeers) { // (b/476510024): Temporary workaround for ces @CanIgnoreReturnValue public Builder clearBeforeModelCallbacks() { - callbackPluginBuilder.clearBeforeModelCallbacks(); + this.beforeModelCallback = null; return this; } @CanIgnoreReturnValue public Builder beforeModelCallback(BeforeModelCallback beforeModelCallback) { - callbackPluginBuilder.addBeforeModelCallback(beforeModelCallback); + this.beforeModelCallback = ImmutableList.of(beforeModelCallback); return this; } @CanIgnoreReturnValue public Builder beforeModelCallback(List beforeModelCallback) { - beforeModelCallback.forEach(callbackPluginBuilder::addCallback); + if (beforeModelCallback == null) { + this.beforeModelCallback = null; + } else if (beforeModelCallback.isEmpty()) { + this.beforeModelCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (BeforeModelCallbackBase callback : beforeModelCallback) { + if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { + builder.add(beforeModelCallbackInstance); + } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { + builder.add( + (BeforeModelCallback) + (callbackContext, llmRequestBuilder) -> + Maybe.fromOptional( + beforeModelCallbackSyncInstance.call( + callbackContext, llmRequestBuilder))); + } else { + logger.warn( + "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.beforeModelCallback = builder.build(); + } + return this; } @CanIgnoreReturnValue public Builder beforeModelCallbackSync(BeforeModelCallbackSync beforeModelCallbackSync) { - callbackPluginBuilder.addBeforeModelCallbackSync(beforeModelCallbackSync); + this.beforeModelCallback = + ImmutableList.of( + (callbackContext, llmRequestBuilder) -> + Maybe.fromOptional( + beforeModelCallbackSync.call(callbackContext, llmRequestBuilder))); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(AfterModelCallback afterModelCallback) { - callbackPluginBuilder.addAfterModelCallback(afterModelCallback); + this.afterModelCallback = ImmutableList.of(afterModelCallback); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(List afterModelCallback) { - afterModelCallback.forEach(callbackPluginBuilder::addCallback); + if (afterModelCallback == null) { + this.afterModelCallback = null; + } else if (afterModelCallback.isEmpty()) { + this.afterModelCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (AfterModelCallbackBase callback : afterModelCallback) { + if (callback instanceof AfterModelCallback afterModelCallbackInstance) { + builder.add(afterModelCallbackInstance); + } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { + builder.add( + (AfterModelCallback) + (callbackContext, llmResponse) -> + Maybe.fromOptional( + afterModelCallbackSyncInstance.call(callbackContext, llmResponse))); + } else { + logger.warn( + "Invalid afterModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.afterModelCallback = builder.build(); + } + return this; } @CanIgnoreReturnValue public Builder afterModelCallbackSync(AfterModelCallbackSync afterModelCallbackSync) { - callbackPluginBuilder.addAfterModelCallbackSync(afterModelCallbackSync); + this.afterModelCallback = + ImmutableList.of( + (callbackContext, llmResponse) -> + Maybe.fromOptional(afterModelCallbackSync.call(callbackContext, llmResponse))); return this; } @CanIgnoreReturnValue public Builder beforeAgentCallbackSync(BeforeAgentCallbackSync beforeAgentCallbackSync) { - callbackPluginBuilder.addBeforeAgentCallbackSync(beforeAgentCallbackSync); + this.beforeAgentCallback = + ImmutableList.of( + (callbackContext) -> + Maybe.fromOptional(beforeAgentCallbackSync.call(callbackContext))); return this; } @CanIgnoreReturnValue public Builder afterAgentCallbackSync(AfterAgentCallbackSync afterAgentCallbackSync) { - callbackPluginBuilder.addAfterAgentCallbackSync(afterAgentCallbackSync); + this.afterAgentCallback = + ImmutableList.of( + (callbackContext) -> + Maybe.fromOptional(afterAgentCallbackSync.call(callbackContext))); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback(BeforeToolCallback beforeToolCallback) { - callbackPluginBuilder.addBeforeToolCallback(beforeToolCallback); + this.beforeToolCallback = ImmutableList.of(beforeToolCallback); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback( @Nullable List beforeToolCallbacks) { - beforeToolCallbacks.forEach(callbackPluginBuilder::addCallback); + if (beforeToolCallbacks == null) { + this.beforeToolCallback = null; + } else if (beforeToolCallbacks.isEmpty()) { + this.beforeToolCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (BeforeToolCallbackBase callback : beforeToolCallbacks) { + if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { + builder.add(beforeToolCallbackInstance); + } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { + builder.add( + (invocationContext, baseTool, input, toolContext) -> + Maybe.fromOptional( + beforeToolCallbackSyncInstance.call( + invocationContext, baseTool, input, toolContext))); + } else { + logger.warn( + "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.beforeToolCallback = builder.build(); + } return this; } @CanIgnoreReturnValue public Builder beforeToolCallbackSync(BeforeToolCallbackSync beforeToolCallbackSync) { - callbackPluginBuilder.addBeforeToolCallbackSync(beforeToolCallbackSync); + this.beforeToolCallback = + ImmutableList.of( + (invocationContext, baseTool, input, toolContext) -> + Maybe.fromOptional( + beforeToolCallbackSync.call( + invocationContext, baseTool, input, toolContext))); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(AfterToolCallback afterToolCallback) { - callbackPluginBuilder.addAfterToolCallback(afterToolCallback); + this.afterToolCallback = ImmutableList.of(afterToolCallback); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(@Nullable List afterToolCallbacks) { - afterToolCallbacks.forEach(callbackPluginBuilder::addCallback); + if (afterToolCallbacks == null) { + this.afterToolCallback = null; + } else if (afterToolCallbacks.isEmpty()) { + this.afterToolCallback = ImmutableList.of(); + } else { + ImmutableList.Builder builder = ImmutableList.builder(); + for (AfterToolCallbackBase callback : afterToolCallbacks) { + if (callback instanceof AfterToolCallback afterToolCallbackInstance) { + builder.add(afterToolCallbackInstance); + } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { + builder.add( + (invocationContext, baseTool, input, toolContext, response) -> + Maybe.fromOptional( + afterToolCallbackSyncInstance.call( + invocationContext, baseTool, input, toolContext, response))); + } else { + logger.warn( + "Invalid afterToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + } + this.afterToolCallback = builder.build(); + } return this; } @CanIgnoreReturnValue public Builder afterToolCallbackSync(AfterToolCallbackSync afterToolCallbackSync) { - callbackPluginBuilder.addAfterToolCallbackSync(afterToolCallbackSync); + this.afterToolCallback = + ImmutableList.of( + (invocationContext, baseTool, input, toolContext, response) -> + Maybe.fromOptional( + afterToolCallbackSync.call( + invocationContext, baseTool, input, toolContext, response))); return this; } @@ -636,19 +764,19 @@ public boolean disallowTransferToPeers() { } public Optional> beforeModelCallback() { - return Optional.of(callbackPlugin.getBeforeModelCallback()); + return beforeModelCallback; } public Optional> afterModelCallback() { - return Optional.of(callbackPlugin.getAfterModelCallback()); + return afterModelCallback; } public Optional> beforeToolCallback() { - return Optional.of(callbackPlugin.getBeforeToolCallback()); + return beforeToolCallback; } public Optional> afterToolCallback() { - return Optional.of(callbackPlugin.getAfterToolCallback()); + return afterToolCallback; } public Optional inputSchema() { @@ -709,8 +837,8 @@ private Model resolveModelInternal() { } BaseAgent current = this.parentAgent(); while (current != null) { - if (current instanceof LlmAgent llmAgent) { - return llmAgent.resolvedModel(); + if (current instanceof LlmAgent) { + return ((LlmAgent) current).resolvedModel(); } current = current.parentAgent(); } diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index 921ef368..d9d049f8 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -46,13 +46,16 @@ public class LoopAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private LoopAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); - this.maxIterations = builder.maxIterations; + private LoopAgent( + String name, + String description, + List subAgents, + Optional maxIterations, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + this.maxIterations = maxIterations; } /** Builder for {@link LoopAgent}. */ @@ -73,7 +76,9 @@ public Builder maxIterations(Optional maxIterations) { @Override public LoopAgent build() { - return new LoopAgent(this); + // TODO(b/410859954): Add validation for required fields like name. + return new LoopAgent( + name, description, subAgents, maxIterations, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index 583bfffc..f30d951a 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -45,12 +45,14 @@ public class ParallelAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private ParallelAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); + private ParallelAgent( + String name, + String description, + List subAgents, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); } /** Builder for {@link ParallelAgent}. */ @@ -58,7 +60,8 @@ public static class Builder extends BaseAgent.Builder { @Override public ParallelAgent build() { - return new ParallelAgent(this); + return new ParallelAgent( + name, description, subAgents, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java index dc7480f5..7d3a5acb 100644 --- a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java +++ b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java @@ -34,11 +34,6 @@ public ReadonlyContext(InvocationContext invocationContext) { this.invocationContext = invocationContext; } - /** Returns the invocation context. */ - public InvocationContext invocationContext() { - return invocationContext; - } - /** Returns the user content that initiated this invocation. */ public Optional userContent() { return invocationContext.userContent(); diff --git a/core/src/main/java/com/google/adk/agents/SequentialAgent.java b/core/src/main/java/com/google/adk/agents/SequentialAgent.java index aa4b76fb..b0b45a0e 100644 --- a/core/src/main/java/com/google/adk/agents/SequentialAgent.java +++ b/core/src/main/java/com/google/adk/agents/SequentialAgent.java @@ -18,6 +18,7 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; import io.reactivex.rxjava3.core.Flowable; +import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,12 +36,14 @@ public class SequentialAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private SequentialAgent(Builder builder) { - super( - builder.name, - builder.description, - builder.subAgents, - builder.callbackPluginBuilder.build()); + private SequentialAgent( + String name, + String description, + List subAgents, + List beforeAgentCallback, + List afterAgentCallback) { + + super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); } /** Builder for {@link SequentialAgent}. */ @@ -48,7 +51,9 @@ public static class Builder extends BaseAgent.Builder { @Override public SequentialAgent build() { - return new SequentialAgent(this); + // TODO(b/410859954): Add validation for required fields like name. + return new SequentialAgent( + name, description, subAgents, beforeAgentCallback, afterAgentCallback); } } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 307b159f..5e6331b7 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -199,7 +199,7 @@ private Flowable callLlm( .onErrorResumeNext( exception -> context - .combinedPlugin() + .pluginManager() .onModelErrorCallback( new CallbackContext( context, eventForCallbackUsage.actions()), @@ -243,9 +243,27 @@ private Single> handleBeforeModelCallback( Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); - return context - .combinedPlugin() - .beforeModelCallback(callbackContext, llmRequestBuilder) + Maybe pluginResult = + context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); + + LlmAgent agent = (LlmAgent) context.agent(); + + Optional> callbacksOpt = agent.beforeModelCallback(); + if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { + return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); + } + + List callbacks = callbacksOpt.get(); + + Maybe callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .firstElement()); + + return pluginResult + .switchIfEmpty(callbackResult) .map(Optional::of) .defaultIfEmpty(Optional.empty()); } @@ -261,10 +279,24 @@ private Single handleAfterModelCallback( Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); - return context - .combinedPlugin() - .afterModelCallback(callbackContext, llmResponse) - .defaultIfEmpty(llmResponse); + Maybe pluginResult = + context.pluginManager().afterModelCallback(callbackContext, llmResponse); + + LlmAgent agent = (LlmAgent) context.agent(); + Optional> callbacksOpt = agent.afterModelCallback(); + + if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { + return pluginResult.defaultIfEmpty(llmResponse); + } + + Maybe callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacksOpt.get()) + .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 9a81da8e..8536e470 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -22,7 +22,10 @@ import com.google.adk.Telemetry; import com.google.adk.agents.ActiveStreamingTool; +import com.google.adk.agents.Callbacks.AfterToolCallback; +import com.google.adk.agents.Callbacks.BeforeToolCallback; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig.ToolExecutionMode; import com.google.adk.events.Event; import com.google.adk.events.EventActions; @@ -385,7 +388,7 @@ private static Maybe postProcessFunctionResult( .onErrorResumeNext( t -> invocationContext - .combinedPlugin() + .pluginManager() .onToolErrorCallback(tool, functionArgs, toolContext, t) .map(isLive ? Optional::ofNullable : Optional::of) .switchIfEmpty(Single.error(t))) @@ -454,7 +457,30 @@ private static Maybe> maybeInvokeBeforeToolCall( BaseTool tool, Map functionArgs, ToolContext toolContext) { - return invocationContext.combinedPlugin().beforeToolCallback(tool, functionArgs, toolContext); + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); + + Maybe> pluginResult = + invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); + + Optional> callbacksOpt = agent.beforeToolCallback(); + if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { + return pluginResult; + } + List callbacks = callbacksOpt.get(); + + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback.call(invocationContext, tool, functionArgs, toolContext)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } + return Maybe.empty(); } private static Maybe> maybeInvokeAfterToolCall( @@ -463,9 +489,37 @@ private static Maybe> maybeInvokeAfterToolCall( Map functionArgs, ToolContext toolContext, Map functionResult) { - return invocationContext - .combinedPlugin() - .afterToolCallback(tool, functionArgs, toolContext, functionResult); + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); + + Maybe> pluginResult = + invocationContext + .pluginManager() + .afterToolCallback(tool, functionArgs, toolContext, functionResult); + + Optional> callbacksOpt = agent.afterToolCallback(); + if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { + return pluginResult; + } + List callbacks = callbacksOpt.get(); + + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback.call( + invocationContext, + tool, + functionArgs, + toolContext, + functionResult)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); + } + return Maybe.empty(); } private static Maybe> callTool( diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 155511bb..e047d1e6 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -529,7 +529,7 @@ private Flowable runAgentWithFreshSession( // TODO: remove this hack after deprecating runAsync with Session. copySessionStates(updatedSession, session); return contextWithUpdatedSession - .combinedPlugin() + .pluginManager() .onEventCallback(contextWithUpdatedSession, registeredEvent) .defaultIfEmpty(registeredEvent); }) diff --git a/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java b/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java deleted file mode 100644 index 361c8619..00000000 --- a/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java +++ /dev/null @@ -1,499 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.adk.agents; - -import static com.google.adk.testing.TestUtils.createInvocationContext; -import static com.google.common.truth.Truth.assertThat; - -import com.google.adk.agents.Callbacks.AfterAgentCallback; -import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; -import com.google.adk.agents.Callbacks.AfterModelCallback; -import com.google.adk.agents.Callbacks.AfterModelCallbackSync; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.AfterToolCallbackSync; -import com.google.adk.agents.Callbacks.BeforeAgentCallback; -import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; -import com.google.adk.agents.Callbacks.BeforeModelCallback; -import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; -import com.google.adk.agents.Callbacks.BeforeToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; -import com.google.adk.events.EventActions; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import com.google.adk.testing.TestCallback; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.common.collect.ImmutableMap; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -@RunWith(JUnit4.class) -public final class CallbackPluginTest { - - @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - @Mock private BaseAgent agent; - @Mock private BaseTool tool; - @Mock private ToolContext toolContext; - private InvocationContext invocationContext; - private CallbackContext callbackContext; - - @Before - public void setUp() { - invocationContext = createInvocationContext(agent); - callbackContext = - new CallbackContext( - invocationContext, - EventActions.builder().stateDelta(new ConcurrentHashMap<>()).build()); - } - - @Test - public void build_empty_successful() { - CallbackPlugin plugin = CallbackPlugin.builder().build(); - assertThat(plugin.getName()).isEqualTo("CallbackPlugin"); - assertThat(plugin.getBeforeAgentCallback()).isEmpty(); - assertThat(plugin.getAfterAgentCallback()).isEmpty(); - assertThat(plugin.getBeforeModelCallback()).isEmpty(); - assertThat(plugin.getAfterModelCallback()).isEmpty(); - assertThat(plugin.getBeforeToolCallback()).isEmpty(); - assertThat(plugin.getAfterToolCallback()).isEmpty(); - } - - @Test - public void addBeforeAgentCallback_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - BeforeAgentCallback callback = testCallback.asBeforeAgentCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeAgentCallback(callback).build(); - - assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); - - Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addBeforeAgentCallbackSync_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeAgentCallbackSync(testCallback.asBeforeAgentCallbackSync()) - .build(); - - assertThat(plugin.getBeforeAgentCallback()).hasSize(1); - - Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addAfterAgentCallback_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - AfterAgentCallback callback = testCallback.asAfterAgentCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterAgentCallback(callback).build(); - - assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); - - Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addAfterAgentCallbackSync_isReturnedAndInvoked() { - Content expectedContent = Content.fromParts(Part.fromText("test")); - var testCallback = TestCallback.returning(expectedContent); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterAgentCallbackSync(testCallback.asAfterAgentCallbackSync()) - .build(); - - assertThat(plugin.getAfterAgentCallback()).hasSize(1); - - Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedContent); - } - - @Test - public void addBeforeModelCallback_isReturnedAndInvoked() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback = TestCallback.returning(expectedResponse); - BeforeModelCallback callback = testCallback.asBeforeModelCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeModelCallback(callback).build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addBeforeModelCallbackSync_isReturnedAndInvoked() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback = TestCallback.returning(expectedResponse); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallbackSync(testCallback.asBeforeModelCallbackSync()) - .build(); - - assertThat(plugin.getBeforeModelCallback()).hasSize(1); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addAfterModelCallback_isReturnedAndInvoked() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); - var testCallback = TestCallback.returning(expectedResponse); - AfterModelCallback callback = testCallback.asAfterModelCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallback(callback).build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addAfterModelCallbackSync_isReturnedAndInvoked() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); - var testCallback = TestCallback.returning(expectedResponse); - AfterModelCallbackSync callback = testCallback.asAfterModelCallbackSync(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallbackSync(callback).build(); - - assertThat(plugin.getAfterModelCallback()).hasSize(1); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addBeforeToolCallback_isReturnedAndInvoked() { - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - BeforeToolCallback callback = testCallback.asBeforeToolCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addBeforeToolCallback(callback).build(); - - assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); - - Map result = - plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addBeforeToolCallbackSync_isReturnedAndInvoked() { - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeToolCallbackSync(testCallback.asBeforeToolCallbackSync()) - .build(); - - assertThat(plugin.getBeforeToolCallback()).hasSize(1); - - Map result = - plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addAfterToolCallback_isReturnedAndInvoked() { - ImmutableMap initialResult = ImmutableMap.of("initial", "result"); - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - AfterToolCallback callback = testCallback.asAfterToolCallback(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallback(callback).build(); - - assertThat(plugin.getAfterToolCallback()).containsExactly(callback); - - Map result = - plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addAfterToolCallbackSync_isReturnedAndInvoked() { - ImmutableMap initialResult = ImmutableMap.of("initial", "result"); - ImmutableMap expectedResult = ImmutableMap.of("key", "value"); - var testCallback = TestCallback.returning(expectedResult); - AfterToolCallbackSync callback = testCallback.asAfterToolCallbackSync(); - - CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallbackSync(callback).build(); - - assertThat(plugin.getAfterToolCallback()).hasSize(1); - - Map result = - plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); - assertThat(testCallback.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResult); - } - - @Test - public void addCallback_beforeAgentCallback() { - BeforeAgentCallback callback = ctx -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeAgentCallbackSync() { - BeforeAgentCallbackSync callback = ctx -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeAgentCallback()).hasSize(1); - } - - @Test - public void addCallback_afterAgentCallback() { - AfterAgentCallback callback = ctx -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterAgentCallbackSync() { - AfterAgentCallbackSync callback = ctx -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterAgentCallback()).hasSize(1); - } - - @Test - public void addCallback_beforeModelCallback() { - BeforeModelCallback callback = (ctx, req) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeModelCallbackSync() { - BeforeModelCallbackSync callback = (ctx, req) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeModelCallback()).hasSize(1); - } - - @Test - public void addCallback_afterModelCallback() { - AfterModelCallback callback = (ctx, res) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterModelCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterModelCallbackSync() { - AfterModelCallbackSync callback = (ctx, res) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterModelCallback()).hasSize(1); - } - - @Test - public void addCallback_beforeToolCallback() { - BeforeToolCallback callback = (invCtx, tool, toolArgs, toolCtx) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); - } - - @Test - public void addCallback_beforeToolCallbackSync() { - BeforeToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getBeforeToolCallback()).hasSize(1); - } - - @Test - public void addCallback_afterToolCallback() { - AfterToolCallback callback = (invCtx, tool, toolArgs, toolCtx, res) -> Maybe.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterToolCallback()).containsExactly(callback); - } - - @Test - public void addCallback_afterToolCallbackSync() { - AfterToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx, res) -> Optional.empty(); - CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); - assertThat(plugin.getAfterToolCallback()).hasSize(1); - } - - @Test - public void addMultipleBeforeModelCallbacks_invokedInOrder() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returning(expectedResponse); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleBeforeModelCallbacks_shortCircuit() { - LlmResponse expectedResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returning(expectedResponse); - var testCallback2 = TestCallback.returningEmpty(); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isFalse(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleAfterModelCallbacks_shortCircuit() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("response"))).build(); - var testCallback1 = TestCallback.returning(expectedResponse); - var testCallback2 = TestCallback.returningEmpty(); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isFalse(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleAfterModelCallbacks_invokedInOrder() { - LlmResponse initialResponse = LlmResponse.builder().build(); - LlmResponse expectedResponse = - LlmResponse.builder().content(Content.fromParts(Part.fromText("second"))).build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returning(expectedResponse); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isEqualTo(expectedResponse); - } - - @Test - public void addMultipleBeforeModelCallbacks_bothEmpty_returnsEmpty() { - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returningEmpty(); - BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); - BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); - - CallbackPlugin plugin = - CallbackPlugin.builder() - .addBeforeModelCallback(callback1) - .addBeforeModelCallback(callback2) - .build(); - - assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); - - LlmResponse result = - plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isNull(); - } - - @Test - public void addMultipleAfterModelCallbacks_bothEmpty_returnsEmpty() { - LlmResponse initialResponse = LlmResponse.builder().build(); - var testCallback1 = TestCallback.returningEmpty(); - var testCallback2 = TestCallback.returningEmpty(); - AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); - AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); - CallbackPlugin plugin = - CallbackPlugin.builder() - .addAfterModelCallback(callback1) - .addAfterModelCallback(callback2) - .build(); - - assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); - LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); - assertThat(testCallback1.wasCalled()).isTrue(); - assertThat(testCallback2.wasCalled()).isTrue(); - assertThat(result).isNull(); - } -}