Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions core/src/main/java/com/google/adk/agents/RunConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ public enum ToolExecutionMode {

public abstract int maxLlmCalls();

public abstract boolean autoCreateSession();

public abstract Builder toBuilder();

public static Builder builder() {
Expand All @@ -78,7 +80,8 @@ public static Builder builder() {
.setResponseModalities(ImmutableList.of())
.setStreamingMode(StreamingMode.NONE)
.setToolExecutionMode(ToolExecutionMode.NONE)
.setMaxLlmCalls(500);
.setMaxLlmCalls(500)
.setAutoCreateSession(false);
}

public static Builder builder(RunConfig runConfig) {
Expand All @@ -90,7 +93,8 @@ public static Builder builder(RunConfig runConfig) {
.setResponseModalities(runConfig.responseModalities())
.setSpeechConfig(runConfig.speechConfig())
.setOutputAudioTranscription(runConfig.outputAudioTranscription())
.setInputAudioTranscription(runConfig.inputAudioTranscription());
.setInputAudioTranscription(runConfig.inputAudioTranscription())
.setAutoCreateSession(runConfig.autoCreateSession());
}

/** Builder for {@link RunConfig}. */
Expand Down Expand Up @@ -123,6 +127,9 @@ public abstract Builder setInputAudioTranscription(
@CanIgnoreReturnValue
public abstract Builder setMaxLlmCalls(int maxLlmCalls);

@CanIgnoreReturnValue
public abstract Builder setAutoCreateSession(boolean autoCreateSession);

abstract RunConfig autoBuild();

public RunConfig build() {
Expand Down
32 changes: 20 additions & 12 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,15 @@ public Flowable<Event> runAsync(
this.sessionService.getSession(appName, userId, sessionId, Optional.empty());
return maybeSession
.switchIfEmpty(
Single.error(
new IllegalArgumentException(
String.format("Session not found: %s for user %s", sessionId, userId))))
Single.defer(
() -> {
if (runConfig.autoCreateSession()) {
return this.sessionService.createSession(appName, userId, null, sessionId);
}
return Single.error(
new IllegalArgumentException(
String.format("Session not found: %s for user %s", sessionId, userId)));
}))
.flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig, stateDelta));
}

Expand Down Expand Up @@ -668,15 +674,17 @@ public Flowable<Event> runLive(
String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) {
return this.sessionService
.getSession(appName, userId, sessionId, Optional.empty())
.flatMapPublisher(
session -> {
if (session == null) {
return Flowable.error(
new IllegalArgumentException(
String.format("Session not found: %s for user %s", sessionId, userId)));
}
return this.runLive(session, liveRequestQueue, runConfig);
});
.switchIfEmpty(
Single.defer(
() -> {
if (runConfig.autoCreateSession()) {
return this.sessionService.createSession(appName, userId, null, sessionId);
}
return Single.error(
new IllegalArgumentException(
String.format("Session not found: %s for user %s", sessionId, userId)));
}))
.flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig));
}

/**
Expand Down
76 changes: 72 additions & 4 deletions core/src/test/java/com/google/adk/runner/RunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.After;
import org.junit.Before;
Expand All @@ -76,10 +77,8 @@ public final class RunnerTest {
private final Content pluginContent = createContent("from plugin");
private final TestLlm testLlm = createTestLlm(createLlmResponse(createContent("from llm")));
private final LlmAgent agent = createTestAgentBuilder(testLlm).build();
private final Runner runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
private final Session session =
runner.sessionService().createSession("test", "user").blockingGet();
private Runner runner;
private Session session;
private Tracer originalTracer;

private final FailingEchoTool failingEchoTool = new FailingEchoTool();
Expand Down Expand Up @@ -114,6 +113,9 @@ private BasePlugin mockPlugin(String name) {
public void setUp() {
this.originalTracer = Telemetry.getTracer();
Telemetry.setTracerForTesting(openTelemetryRule.getOpenTelemetry().getTracer("RunnerTest"));
this.runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
this.session = runner.sessionService().createSession("test", "user").blockingGet();
}

@After
Expand Down Expand Up @@ -832,4 +834,70 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() {
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
assertThat(contextCaptor.getValue().isResumable()).isFalse();
}

@Test
public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() {
RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build();
String newSessionId = UUID.randomUUID().toString();

var events =
runner
.runAsync("user", newSessionId, createContent("from user"), runConfig)
.toList()
.blockingGet();

assertThat(simplifyEvents(events)).containsExactly("test agent: from llm");
assertThat(
runner
.sessionService()
.getSession("test", "user", newSessionId, Optional.empty())
.blockingGet())
.isNotNull();
}

@Test
public void runAsync_withoutSessionAndAutoCreateSessionFalse_throwsException() {
RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build();
String newSessionId = UUID.randomUUID().toString();

runner
.runAsync("user", newSessionId, createContent("from user"), runConfig)
.test()
.assertError(IllegalArgumentException.class);
}

@Test
public void runLive_withoutSessionAndAutoCreateSessionTrue_createsSession() throws Exception {
RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build();
String newSessionId = UUID.randomUUID().toString();
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();

TestSubscriber<Event> testSubscriber =
runner.runLive("user", newSessionId, liveRequestQueue, runConfig).test();

liveRequestQueue.content(createContent("from user"));
liveRequestQueue.close();

testSubscriber.await();
testSubscriber.assertComplete();
assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm");
assertThat(
runner
.sessionService()
.getSession("test", "user", newSessionId, Optional.empty())
.blockingGet())
.isNotNull();
}

@Test
public void runLive_withoutSessionAndAutoCreateSessionFalse_throwsException() {
RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build();
String newSessionId = UUID.randomUUID().toString();
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();

runner
.runLive("user", newSessionId, liveRequestQueue, runConfig)
.test()
.assertError(IllegalArgumentException.class);
}
}