Skip to content

Commit d76353d

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Adding autoCreateSession in Runner
PiperOrigin-RevId: 858855558
1 parent 5d0e0ae commit d76353d

3 files changed

Lines changed: 111 additions & 14 deletions

File tree

core/src/main/java/com/google/adk/agents/RunConfig.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ public enum ToolExecutionMode {
7070

7171
public abstract int maxLlmCalls();
7272

73+
public abstract boolean autoCreateSession();
74+
7375
public abstract Builder toBuilder();
7476

7577
public static Builder builder() {
@@ -78,7 +80,8 @@ public static Builder builder() {
7880
.setResponseModalities(ImmutableList.of())
7981
.setStreamingMode(StreamingMode.NONE)
8082
.setToolExecutionMode(ToolExecutionMode.NONE)
81-
.setMaxLlmCalls(500);
83+
.setMaxLlmCalls(500)
84+
.setAutoCreateSession(false);
8285
}
8386

8487
public static Builder builder(RunConfig runConfig) {
@@ -90,7 +93,8 @@ public static Builder builder(RunConfig runConfig) {
9093
.setResponseModalities(runConfig.responseModalities())
9194
.setSpeechConfig(runConfig.speechConfig())
9295
.setOutputAudioTranscription(runConfig.outputAudioTranscription())
93-
.setInputAudioTranscription(runConfig.inputAudioTranscription());
96+
.setInputAudioTranscription(runConfig.inputAudioTranscription())
97+
.setAutoCreateSession(runConfig.autoCreateSession());
9498
}
9599

96100
/** Builder for {@link RunConfig}. */
@@ -123,6 +127,9 @@ public abstract Builder setInputAudioTranscription(
123127
@CanIgnoreReturnValue
124128
public abstract Builder setMaxLlmCalls(int maxLlmCalls);
125129

130+
@CanIgnoreReturnValue
131+
public abstract Builder setAutoCreateSession(boolean autoCreateSession);
132+
126133
abstract RunConfig autoBuild();
127134

128135
public RunConfig build() {

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,15 @@ public Flowable<Event> runAsync(
378378
this.sessionService.getSession(appName, userId, sessionId, Optional.empty());
379379
return maybeSession
380380
.switchIfEmpty(
381-
Single.error(
382-
new IllegalArgumentException(
383-
String.format("Session not found: %s for user %s", sessionId, userId))))
381+
Single.defer(
382+
() -> {
383+
if (runConfig.autoCreateSession()) {
384+
return this.sessionService.createSession(appName, userId, null, sessionId);
385+
}
386+
return Single.error(
387+
new IllegalArgumentException(
388+
String.format("Session not found: %s for user %s", sessionId, userId)));
389+
}))
384390
.flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig, stateDelta));
385391
}
386392

@@ -668,15 +674,17 @@ public Flowable<Event> runLive(
668674
String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) {
669675
return this.sessionService
670676
.getSession(appName, userId, sessionId, Optional.empty())
671-
.flatMapPublisher(
672-
session -> {
673-
if (session == null) {
674-
return Flowable.error(
675-
new IllegalArgumentException(
676-
String.format("Session not found: %s for user %s", sessionId, userId)));
677-
}
678-
return this.runLive(session, liveRequestQueue, runConfig);
679-
});
677+
.switchIfEmpty(
678+
Single.defer(
679+
() -> {
680+
if (runConfig.autoCreateSession()) {
681+
return this.sessionService.createSession(appName, userId, null, sessionId);
682+
}
683+
return Single.error(
684+
new IllegalArgumentException(
685+
String.format("Session not found: %s for user %s", sessionId, userId)));
686+
}))
687+
.flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig));
680688
}
681689

682690
/**

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,4 +832,86 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() {
832832
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
833833
assertThat(contextCaptor.getValue().isResumable()).isFalse();
834834
}
835+
836+
@Test
837+
public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() {
838+
RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build();
839+
String newSessionId = "new-session-id";
840+
841+
// Make sure session doesn't exist.
842+
assertThat(
843+
runner
844+
.sessionService()
845+
.getSession("test", "user", newSessionId, Optional.empty())
846+
.blockingGet())
847+
.isNull();
848+
849+
var events =
850+
runner
851+
.runAsync("user", newSessionId, createContent("from user"), runConfig)
852+
.toList()
853+
.blockingGet();
854+
855+
assertThat(simplifyEvents(events)).containsExactly("test agent: from llm");
856+
assertThat(
857+
runner
858+
.sessionService()
859+
.getSession("test", "user", newSessionId, Optional.empty())
860+
.blockingGet())
861+
.isNotNull();
862+
}
863+
864+
@Test
865+
public void runAsync_withoutSessionAndAutoCreateSessionFalse_throwsException() {
866+
RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build();
867+
String newSessionId = "new-session-id";
868+
869+
runner
870+
.runAsync("user", newSessionId, createContent("from user"), runConfig)
871+
.test()
872+
.assertError(IllegalArgumentException.class);
873+
}
874+
875+
@Test
876+
public void runLive_withoutSessionAndAutoCreateSessionTrue_createsSession() throws Exception {
877+
RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build();
878+
String newSessionId = "new-session-id";
879+
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
880+
881+
// Make sure session doesn't exist.
882+
assertThat(
883+
runner
884+
.sessionService()
885+
.getSession("test", "user", newSessionId, Optional.empty())
886+
.blockingGet())
887+
.isNull();
888+
889+
TestSubscriber<Event> testSubscriber =
890+
runner.runLive("user", newSessionId, liveRequestQueue, runConfig).test();
891+
892+
liveRequestQueue.content(createContent("from user"));
893+
liveRequestQueue.close();
894+
895+
testSubscriber.await();
896+
testSubscriber.assertComplete();
897+
assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm");
898+
assertThat(
899+
runner
900+
.sessionService()
901+
.getSession("test", "user", newSessionId, Optional.empty())
902+
.blockingGet())
903+
.isNotNull();
904+
}
905+
906+
@Test
907+
public void runLive_withoutSessionAndAutoCreateSessionFalse_throwsException() {
908+
RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build();
909+
String newSessionId = "new-session-id";
910+
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
911+
912+
runner
913+
.runLive("user", newSessionId, liveRequestQueue, runConfig)
914+
.test()
915+
.assertError(IllegalArgumentException.class);
916+
}
835917
}

0 commit comments

Comments
 (0)