Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 80 additions & 72 deletions core/src/main/java/com/google/adk/agents/CallbackPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import io.reactivex.rxjava3.core.Maybe;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -113,17 +114,27 @@ public static class Builder {

Builder() {}

private <T> void addCallbackImpl(
Class<?> callbackClass, T callback, Supplier<BasePlugin> pluginSupplier) {
if (!callbacks.containsEntry(callbackClass, callback)) {
callbacks.put(callbackClass, callback);
plugins.add(pluginSupplier.get());
}
}

@CanIgnoreReturnValue
public Builder addBeforeAgentCallback(Callbacks.BeforeAgentCallback callback) {
callbacks.put(Callbacks.BeforeAgentCallback.class, callback);
plugins.add(
new BasePlugin("BeforeAgentCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<Content> beforeAgentCallback(
BaseAgent agent, CallbackContext callbackContext) {
return callback.call(callbackContext);
}
});
addCallbackImpl(
Callbacks.BeforeAgentCallback.class,
callback,
() ->
new BasePlugin("BeforeAgentCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<Content> beforeAgentCallback(
BaseAgent agent, CallbackContext callbackContext) {
return callback.call(callbackContext);
}
});
return this;
}

Expand All @@ -135,15 +146,17 @@ public Builder addBeforeAgentCallbackSync(Callbacks.BeforeAgentCallbackSync call

@CanIgnoreReturnValue
public Builder addAfterAgentCallback(Callbacks.AfterAgentCallback callback) {
callbacks.put(Callbacks.AfterAgentCallback.class, callback);
plugins.add(
new BasePlugin("AfterAgentCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<Content> afterAgentCallback(
BaseAgent agent, CallbackContext callbackContext) {
return callback.call(callbackContext);
}
});
addCallbackImpl(
Callbacks.AfterAgentCallback.class,
callback,
() ->
new BasePlugin("AfterAgentCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<Content> afterAgentCallback(
BaseAgent agent, CallbackContext callbackContext) {
return callback.call(callbackContext);
}
});
return this;
}

Expand All @@ -155,15 +168,17 @@ public Builder addAfterAgentCallbackSync(Callbacks.AfterAgentCallbackSync callba

@CanIgnoreReturnValue
public Builder addBeforeModelCallback(Callbacks.BeforeModelCallback callback) {
callbacks.put(Callbacks.BeforeModelCallback.class, callback);
plugins.add(
new BasePlugin("BeforeModelCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<LlmResponse> beforeModelCallback(
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
return callback.call(callbackContext, llmRequest);
}
});
addCallbackImpl(
Callbacks.BeforeModelCallback.class,
callback,
() ->
new BasePlugin("BeforeModelCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<LlmResponse> beforeModelCallback(
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
return callback.call(callbackContext, llmRequest);
}
});
return this;
}

Expand All @@ -174,31 +189,19 @@ public Builder addBeforeModelCallbackSync(Callbacks.BeforeModelCallbackSync call
Maybe.fromOptional(callback.call(callbackContext, llmRequest)));
}

// (b/476510024): Temporary workaround for ces
@CanIgnoreReturnValue
public Builder clearBeforeModelCallbacks() {
callbacks.removeAll(Callbacks.BeforeModelCallback.class);
ImmutableList.Builder<BasePlugin> updatedPlugins = ImmutableList.builder();
for (BasePlugin plugin : plugins.build()) {
if (!plugin.getName().startsWith("BeforeModelCallback_")) {
updatedPlugins.add(plugin);
}
}
plugins = updatedPlugins;
return this;
}

@CanIgnoreReturnValue
public Builder addAfterModelCallback(Callbacks.AfterModelCallback callback) {
callbacks.put(Callbacks.AfterModelCallback.class, callback);
plugins.add(
new BasePlugin("AfterModelCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<LlmResponse> afterModelCallback(
CallbackContext callbackContext, LlmResponse llmResponse) {
return callback.call(callbackContext, llmResponse);
}
});
addCallbackImpl(
Callbacks.AfterModelCallback.class,
callback,
() ->
new BasePlugin("AfterModelCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<LlmResponse> afterModelCallback(
CallbackContext callbackContext, LlmResponse llmResponse) {
return callback.call(callbackContext, llmResponse);
}
});
return this;
}

Expand All @@ -211,15 +214,18 @@ public Builder addAfterModelCallbackSync(Callbacks.AfterModelCallbackSync callba

@CanIgnoreReturnValue
public Builder addBeforeToolCallback(Callbacks.BeforeToolCallback callback) {
callbacks.put(Callbacks.BeforeToolCallback.class, callback);
plugins.add(
new BasePlugin("BeforeToolCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<Map<String, Object>> beforeToolCallback(
BaseTool tool, Map<String, Object> toolArgs, ToolContext toolContext) {
return callback.call(toolContext.invocationContext(), tool, toolArgs, toolContext);
}
});
addCallbackImpl(
Callbacks.BeforeToolCallback.class,
callback,
() ->
new BasePlugin("BeforeToolCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<Map<String, Object>> beforeToolCallback(
BaseTool tool, Map<String, Object> toolArgs, ToolContext toolContext) {
return callback.call(
toolContext.invocationContext(), tool, toolArgs, toolContext);
}
});
return this;
}

Expand All @@ -232,19 +238,21 @@ public Builder addBeforeToolCallbackSync(Callbacks.BeforeToolCallbackSync callba

@CanIgnoreReturnValue
public Builder addAfterToolCallback(Callbacks.AfterToolCallback callback) {
callbacks.put(Callbacks.AfterToolCallback.class, callback);
plugins.add(
new BasePlugin("AfterToolCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<Map<String, Object>> afterToolCallback(
BaseTool tool,
Map<String, Object> toolArgs,
ToolContext toolContext,
Map<String, Object> result) {
return callback.call(
toolContext.invocationContext(), tool, toolArgs, toolContext, result);
}
});
addCallbackImpl(
Callbacks.AfterToolCallback.class,
callback,
() ->
new BasePlugin("AfterToolCallback_" + callbackId.getAndIncrement()) {
@Override
public Maybe<Map<String, Object>> afterToolCallback(
BaseTool tool,
Map<String, Object> toolArgs,
ToolContext toolContext,
Map<String, Object> result) {
return callback.call(
toolContext.invocationContext(), tool, toolArgs, toolContext, result);
}
});
return this;
}

Expand Down
7 changes: 0 additions & 7 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,6 @@ public Builder disallowTransferToPeers(boolean disallowTransferToPeers) {
return this;
}

// (b/476510024): Temporary workaround for ces
@CanIgnoreReturnValue
public Builder clearBeforeModelCallbacks() {
callbackPluginBuilder.clearBeforeModelCallbacks();
return this;
}

@CanIgnoreReturnValue
public Builder beforeModelCallback(BeforeModelCallback beforeModelCallback) {
callbackPluginBuilder.addBeforeModelCallback(beforeModelCallback);
Expand Down
13 changes: 13 additions & 0 deletions core/src/test/java/com/google/adk/agents/CallbackPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -496,4 +496,17 @@ public void addMultipleAfterModelCallbacks_bothEmpty_returnsEmpty() {
assertThat(testCallback2.wasCalled()).isTrue();
assertThat(result).isNull();
}

@Test
public void addBeforeModelCallback_isIdempotent() {
LlmResponse expectedResponse = LlmResponse.builder().build();
var testCallback = TestCallback.returning(expectedResponse);
BeforeModelCallback callback = testCallback.asBeforeModelCallback();
CallbackPlugin plugin =
CallbackPlugin.builder()
.addBeforeModelCallback(callback)
.addBeforeModelCallback(callback)
.build();
assertThat(plugin.getBeforeModelCallback()).containsExactly(callback);
}
}