diff --git a/core/src/main/java/com/google/adk/agents/RunConfig.java b/core/src/main/java/com/google/adk/agents/RunConfig.java index 2f8e417d..308169e3 100644 --- a/core/src/main/java/com/google/adk/agents/RunConfig.java +++ b/core/src/main/java/com/google/adk/agents/RunConfig.java @@ -70,6 +70,8 @@ public enum ToolExecutionMode { public abstract int maxLlmCalls(); + public abstract boolean autoCreateSession(); + public abstract Builder toBuilder(); public static Builder builder() { @@ -78,7 +80,8 @@ public static Builder builder() { .setResponseModalities(ImmutableList.of()) .setStreamingMode(StreamingMode.NONE) .setToolExecutionMode(ToolExecutionMode.NONE) - .setMaxLlmCalls(500); + .setMaxLlmCalls(500) + .setAutoCreateSession(false); } public static Builder builder(RunConfig runConfig) { @@ -90,7 +93,8 @@ public static Builder builder(RunConfig runConfig) { .setResponseModalities(runConfig.responseModalities()) .setSpeechConfig(runConfig.speechConfig()) .setOutputAudioTranscription(runConfig.outputAudioTranscription()) - .setInputAudioTranscription(runConfig.inputAudioTranscription()); + .setInputAudioTranscription(runConfig.inputAudioTranscription()) + .setAutoCreateSession(runConfig.autoCreateSession()); } /** Builder for {@link RunConfig}. */ @@ -123,6 +127,9 @@ public abstract Builder setInputAudioTranscription( @CanIgnoreReturnValue public abstract Builder setMaxLlmCalls(int maxLlmCalls); + @CanIgnoreReturnValue + public abstract Builder setAutoCreateSession(boolean autoCreateSession); + abstract RunConfig autoBuild(); public RunConfig build() { diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index e047d1e6..c5c18cdf 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -378,9 +378,15 @@ public Flowable runAsync( this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); return maybeSession .switchIfEmpty( - Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId)))) + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession(appName, userId, null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format("Session not found: %s for user %s", sessionId, userId))); + })) .flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig, stateDelta)); } @@ -668,15 +674,17 @@ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { return this.sessionService .getSession(appName, userId, sessionId, Optional.empty()) - .flatMapPublisher( - session -> { - if (session == null) { - return Flowable.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - } - return this.runLive(session, liveRequestQueue, runConfig); - }); + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession(appName, userId, null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format("Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); } /** diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 994acca2..f1a01100 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -59,6 +59,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import org.junit.After; import org.junit.Before; @@ -76,10 +77,8 @@ public final class RunnerTest { private final Content pluginContent = createContent("from plugin"); private final TestLlm testLlm = createTestLlm(createLlmResponse(createContent("from llm"))); private final LlmAgent agent = createTestAgentBuilder(testLlm).build(); - private final Runner runner = - Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); - private final Session session = - runner.sessionService().createSession("test", "user").blockingGet(); + private Runner runner; + private Session session; private Tracer originalTracer; private final FailingEchoTool failingEchoTool = new FailingEchoTool(); @@ -114,6 +113,9 @@ private BasePlugin mockPlugin(String name) { public void setUp() { this.originalTracer = Telemetry.getTracer(); Telemetry.setTracerForTesting(openTelemetryRule.getOpenTelemetry().getTracer("RunnerTest")); + this.runner = + Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); + this.session = runner.sessionService().createSession("test", "user").blockingGet(); } @After @@ -832,4 +834,70 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() { runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); assertThat(contextCaptor.getValue().isResumable()).isFalse(); } + + @Test + public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { + RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); + String newSessionId = UUID.randomUUID().toString(); + + var events = + runner + .runAsync("user", newSessionId, createContent("from user"), runConfig) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + assertThat( + runner + .sessionService() + .getSession("test", "user", newSessionId, Optional.empty()) + .blockingGet()) + .isNotNull(); + } + + @Test + public void runAsync_withoutSessionAndAutoCreateSessionFalse_throwsException() { + RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build(); + String newSessionId = UUID.randomUUID().toString(); + + runner + .runAsync("user", newSessionId, createContent("from user"), runConfig) + .test() + .assertError(IllegalArgumentException.class); + } + + @Test + public void runLive_withoutSessionAndAutoCreateSessionTrue_createsSession() throws Exception { + RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); + String newSessionId = UUID.randomUUID().toString(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber = + runner.runLive("user", newSessionId, liveRequestQueue, runConfig).test(); + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); + assertThat( + runner + .sessionService() + .getSession("test", "user", newSessionId, Optional.empty()) + .blockingGet()) + .isNotNull(); + } + + @Test + public void runLive_withoutSessionAndAutoCreateSessionFalse_throwsException() { + RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build(); + String newSessionId = UUID.randomUUID().toString(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + runner + .runLive("user", newSessionId, liveRequestQueue, runConfig) + .test() + .assertError(IllegalArgumentException.class); + } }